import {
  INode,
  IRenderContext,
  NodeStyleBase,
  Point,
  SvgVisual,
  Visual,
} from '@ardoq/yfiles';
import {
  CONTEXT_HIGHLIGHT_PADDING,
  DATA_RENDER_HASH,
  NODE_HEIGHT_WIDTH_SELECTION as NODE_HEIGHT,
  NODE_WIDTH,
} from 'yfilesExtensions/styles/consts';

import getNodeBounds from 'yfilesExtensions/getNodeBounds';

import getNodeOutline from './getNodeOutlineClone';
import { isContextNode } from '../../utils';
import { createSvgElement } from '@ardoq/dom-utils';

/**
 * Decides whether `updateVisual` should leave a node's existing visual as-is.
 *
 * Context nodes are compared against the plain `NODE_HEIGHT` / `NODE_WIDTH`.
 * All other nodes are compared against those sizes grown by
 * `CONTEXT_HIGHLIGHT_PADDING` on both sides. A match on either dimension is
 * enough to skip.
 *
 * NOTE(review): this looks like it detects a node whose layout size belongs to
 * the *other* node kind (e.g. mid-transition). TODO confirm the intent.
 *
 * @param node - the node whose layout size is inspected
 * @returns `true` when the re-render should be skipped
 */
const shouldSkipRerender = (node: INode) => {
  const { width, height } = node.layout;

  if (isContextNode(node)) {
    return height === NODE_HEIGHT || width === NODE_WIDTH;
  }

  // Padding is applied on both sides, hence the factor of two.
  const paddedBy = 2 * CONTEXT_HIGHLIGHT_PADDING;
  return height === NODE_HEIGHT + paddedBy || width === NODE_WIDTH + paddedBy;
};

/**
 * Base class for Ardoq SVG node styles.
 *
 * Subclasses implement `render`, which draws a node into a `<g>` container,
 * and `getHash`, a string key for the node's rendered state. The hash is
 * stored on the container under `DATA_RENDER_HASH`. On update, `render` is
 * only called again when that stored hash is missing, empty or differs from
 * the current one.
 */
export default abstract class ArdoqNodeStyleBase extends NodeStyleBase {
  /** Returns a key that changes whenever the output of `render` for `node` would change. */
  protected abstract getHash(node: INode): string;
  /** Draws `node` into `container`, the `<g>` element owned by this style's visual. */
  protected abstract render(container: SVGElement, node: INode): void;

  /**
   * Builds a new `<g>` visual for `node`. The visual is tagged with the
   * current render hash, rendered into and translated to the node's layout
   * position.
   */
  override createVisual(_: IRenderContext, node: INode) {
    const hash = this.getHash(node);
    const container = createSvgElement('g');
    container.setAttribute(DATA_RENDER_HASH, hash);
    this.render(container, node);
    SvgVisual.setTranslate(container, node.layout.x, node.layout.y);
    return new SvgVisual(container);
  }

  /**
   * Reuses `oldVisual` for `node`. It re-renders only when the render hash
   * changed, and then moves the visual to the node's current layout position.
   *
   * If `shouldSkipRerender(node)` is true, the old visual is returned
   * untouched (neither re-rendered nor re-translated).
   */
  override updateVisual(
    context: IRenderContext,
    oldVisual: Visual,
    node: INode
  ): Visual {
    // Only an SvgVisual carries the `<g>` container we render into; anything
    // else cannot be reused, so build a fresh visual instead of crashing on a
    // missing `svgElement`.
    if (!(oldVisual instanceof SvgVisual)) {
      return this.createVisual(context, node);
    }
    // Check this before computing the hash: getHash is subclass-defined and
    // may be costly, and its result would be discarded on this path.
    if (shouldSkipRerender(node)) return oldVisual;

    const container = oldVisual.svgElement;
    const oldHash = container.getAttribute(DATA_RENDER_HASH);
    const newHash = this.getHash(node);
    // `!oldHash` also forces a render when no (or an empty) hash is stored.
    if (!oldHash || oldHash !== newHash) {
      container.setAttribute(DATA_RENDER_HASH, newHash);
      this.render(container, node);
    }
    SvgVisual.setTranslate(container, node.layout.x, node.layout.y);
    return oldVisual;
  }

  /** Delegates to `getNodeOutline`, which is used e.g. by the base `getIntersection`. */
  override getOutline(node: INode) {
    return getNodeOutline(node);
  }

  /**
   * Finds where the segment from `inner` (inside the node) to `outer`
   * (outside) crosses the node's boundary. Exactly vertical segments coming
   * from above are special-cased; see below.
   */
  override getIntersection(node: INode, inner: Point, outer: Point) {
    // this method is called by the edge router when looking for a port to connect to this node.
    // the base implementation will find the intersection point using the result of getOutline(node).
    // `outer` above `inner` (smaller y) on the same x means the edge points straight down into the node.
    const straightDown = inner.y > outer.y && inner.x === outer.x;

    if (straightDown) {
      // because the base implementation uses getOutline() to find the intersection, it would connect downward pointing edges to the top side of the label bounds, which is ugly.
      // like this...

      //  ref  refs refs ref
      //   |   ↓ ↓ ↓ ↓ ↓  |
      //   |   _________  |
      //   |  |  node   | |
      //   ↓  | outline | ↓
      //  ____|         |_____
      // | long label outline |
      // |____________________|

      // we like that the label bounds are part of the outline when the refs come in upward or from the side, but downward is ugly.
      // the workaround is actually an optimization: we simply find the intersection point along the topside of the node.

      const nodeBounds = getNodeBounds(node, isContextNode(node));
      // Clamp x into the node's horizontal extent and snap to its top edge.
      return new Point(
        Math.max(nodeBounds.x, Math.min(nodeBounds.maxX, outer.x)),
        nodeBounds.y
      );
    }
    return super.getIntersection(node, inner, outer);
  }
}
