import {useMemo, useEffect, useState, MutableRefObject} from 'react'
import {useDevicePixelRatio} from '@kensho/tacklebox'

import {APITranscript, TranscriptSelectionNode} from '../types/types'
import {indexFromPath} from '../utils/transcriptUtils'

import {AnnotationsLayerBlock} from './AnnotationsLayerBlock'
import {KHighlight, StyledSegments} from './types'
import {getPersistedHighlights} from './getPersistedHighlights'
import {FULL_TOKEN_OFFSET} from './constants'

/**
 * Orders two transcript positions: first by slice, then by token within the slice.
 *
 * @returns -1 if `a` comes before `b`, 1 if it comes after, 0 if neither (same position).
 */
function compare(
  a: {sliceIndex: number; tokenIndex: number},
  b: {sliceIndex: number; tokenIndex: number},
): 1 | -1 | 0 {
  const order = (left: number, right: number): 1 | -1 | 0 => {
    if (left < right) return -1
    return left > right ? 1 : 0
  }
  // Only consult the token index when the slice comparison is inconclusive.
  return order(a.sliceIndex, b.sliceIndex) || order(a.tokenIndex, b.tokenIndex)
}

/**
 * Draws persisted transcript highlights onto a stack of canvas blocks layered over the transcript.
 *
 * Highlights come from `getPersistedHighlights(transcript)`. Each highlight range is measured
 * against the token elements (`[data-type="token"]`) currently rendered inside `transcriptRef`,
 * restricted to the slices listed in `visibleBatches`. The resulting rectangles are made relative
 * to the transcript element's bounding box. They are passed to one `AnnotationsLayerBlock` per
 * 512 units of `height`.
 *
 * @param props.transcript - transcript whose persisted highlights are drawn
 * @param props.height - total height of the layer; split into blocks of at most 512
 * @param props.width - width given to every block
 * @param props.visibleBatches - the text of each key before its first '-' is treated as a visible slice index
 * @param props.transcriptRef - ref to the element containing the rendered token elements
 * @returns an array of `AnnotationsLayerBlock` elements stacked vertically
 */
export function AnnotationsLayer(props: {
  transcript: APITranscript
  height: number
  width: number
  visibleBatches: Record<string, boolean>
  transcriptRef: MutableRefObject<HTMLDivElement | null>
}): React.ReactNode {
  const {height, width, transcript, visibleBatches, transcriptRef} = props
  const [kHighlights, setKHighlights] = useState<KHighlight[]>([])
  // Serialised key list: the effect below depends on this string, so it only re-runs when the
  // set of keys changes rather than on every new visibleBatches object.
  // NOTE(review): keys are used regardless of their boolean value — confirm callers never store `false`.
  const visibleKeys = useMemo(() => Object.keys(visibleBatches).toString(), [visibleBatches])
  const [styledSegments, setStyledSegments] = useState<StyledSegments[] | undefined>()
  const pixelRatio = useDevicePixelRatio()

  useEffect(() => {
    setKHighlights(getPersistedHighlights(transcript))
  }, [transcript])

  // Measure every highlight range against the rendered DOM and store the resulting rectangles.
  useEffect(() => {
    const transcriptElement = transcriptRef.current
    if (!transcriptElement) return
    const allTokens = transcriptElement.querySelectorAll('[data-type="token"]')

    // Rendered token elements per slice, in DOM order. A token's position in its slice's array
    // is what selection nodes' tokenIndex is resolved against.
    const tokenElementsGroupedBySlice: Record<string, Element[]> = {}
    // First and last rendered token positions among the visible slices (-1 while none seen).
    const availableRange = {
      start: {sliceIndex: -1, tokenIndex: -1},
      end: {sliceIndex: -1, tokenIndex: -1},
    }

    const visibleSliceMap: Record<string, boolean> = {}
    visibleKeys.split(',').forEach((key) => {
      const sliceIndex = key.split('-')[0]
      visibleSliceMap[sliceIndex] = true
    })

    allTokens.forEach((token) => {
      const path = token.getAttribute('data-path')
      if (path) {
        const {sliceIndex} = indexFromPath(path)
        // remove offscreen high LOD tokens from consideration
        if (!visibleSliceMap[sliceIndex]) return
        if (!tokenElementsGroupedBySlice[sliceIndex]) {
          tokenElementsGroupedBySlice[sliceIndex] = []
        }
        tokenElementsGroupedBySlice[sliceIndex].push(token)
        if (availableRange.start.sliceIndex === -1 && availableRange.start.tokenIndex === -1) {
          availableRange.start = {
            sliceIndex,
            tokenIndex: tokenElementsGroupedBySlice[sliceIndex].length - 1,
          }
        }
        availableRange.end = {
          sliceIndex,
          tokenIndex: tokenElementsGroupedBySlice[sliceIndex].length - 1,
        }
      }
    })

    /** Returns the rendered token element at the given position, or null if it is not rendered. */
    function getTokenElement(sliceIndex: number, tokenIndex: number): Element | null {
      return tokenElementsGroupedBySlice[sliceIndex]?.[tokenIndex] ?? null
    }

    /** Resolves a selection node to its rendered element (the token, or the space after it). */
    function getEleFromTranscriptSelection(
      transcriptSelectionNode: TranscriptSelectionNode | null,
    ): Element | null {
      if (!transcriptSelectionNode) return null
      const {type, sliceIndex, tokenIndex} = transcriptSelectionNode
      const ele = getTokenElement(sliceIndex, tokenIndex)
      // if selection is on a token-space the tokenIndex refers to the preceding token so we need to get the next sibling
      if (type === 'token-space') {
        return (ele?.nextSibling as Element | null | undefined) ?? null
      }
      return ele
    }

    const boundingRect = transcriptElement.getBoundingClientRect()
    const nextStyledSegments: StyledSegments[] = []
    kHighlights.forEach((kHighlight) => {
      kHighlight.ranges.forEach((range) => {
        const clientRects: DOMRect[] = []

        if (!range.start || !range.end) return
        const start = {...range.start}
        const end = {...range.end}

        // Skip ranges lying entirely outside the rendered tokens.
        if (compare(end, availableRange.start) === -1 || compare(start, availableRange.end) === 1) {
          return
        }

        // Clamp ranges that extend past the rendered tokens. The original textOffset belongs to a
        // token that is not rendered. Reusing it on the clamped token could exceed that token's text
        // length (Range.setStart/setEnd then throw IndexSizeError) or cut the token at a meaningless
        // point. So a clamped start covers its token from the beginning, and a clamped end covers
        // its token completely.
        const startIsClamped = compare(start, availableRange.start) === -1
        if (startIsClamped) {
          start.sliceIndex = availableRange.start.sliceIndex
          start.tokenIndex = availableRange.start.tokenIndex
          start.textOffset = 0
        }
        const endIsClamped = compare(end, availableRange.end) === 1
        if (endIsClamped) {
          end.sliceIndex = availableRange.end.sliceIndex
          end.tokenIndex = availableRange.end.tokenIndex
          end.textOffset = FULL_TOKEN_OFFSET
        }
        // A clamped position always refers to the token itself, never to the space after it.
        const startElement = startIsClamped
          ? getTokenElement(start.sliceIndex, start.tokenIndex)
          : getEleFromTranscriptSelection(start)
        const endElement = endIsClamped
          ? getTokenElement(end.sliceIndex, end.tokenIndex)
          : getEleFromTranscriptSelection(end)

        if (!startElement || !endElement) return
        // Build ranges and get client rects
        const webRange = document.createRange()
        if (range.type === 'Caret') {
          webRange.setStart(startElement.childNodes[0], start.textOffset)
          webRange.setEnd(endElement.childNodes[0], end.textOffset)
          clientRects.push(...webRange.getClientRects())
        } else if (start.sliceIndex === end.sliceIndex) {
          // The range is part of a single slice: one DOM range covers it.
          webRange.setStart(startElement.childNodes[0], start.textOffset)
          if (end.textOffset === FULL_TOKEN_OFFSET) {
            webRange.setEndAfter(endElement.childNodes[0])
          } else {
            webRange.setEnd(endElement.childNodes[0], end.textOffset)
          }
          clientRects.push(...webRange.getClientRects())
        } else {
          // The range spans several slices. Measure each slice's tokens separately; otherwise the
          // slice headers between them would be included in the rects.
          // First: from the start position to the end of the start slice.
          const endOfStartSliceToken =
            tokenElementsGroupedBySlice[start.sliceIndex][
              tokenElementsGroupedBySlice[start.sliceIndex].length - 1
            ]
          const endOfStartSliceElement = endOfStartSliceToken.nextElementSibling
            ? endOfStartSliceToken.nextElementSibling
            : endOfStartSliceToken
          webRange.setStart(startElement.childNodes[0], start.textOffset)
          webRange.setEndAfter(endOfStartSliceElement)
          clientRects.push(...webRange.getClientRects())
          // Then every rendered slice strictly between start and end, in full.
          for (let i = start.sliceIndex + 1; i < end.sliceIndex; i += 1) {
            const slice = tokenElementsGroupedBySlice[i]
            if (slice) {
              const startofSliceToken = slice[0]
              const endOfSliceToken = slice[slice.length - 1]
              const endOfSliceElement = endOfSliceToken.nextElementSibling
                ? endOfSliceToken.nextElementSibling
                : endOfSliceToken
              webRange.setStartBefore(startofSliceToken)
              webRange.setEndAfter(endOfSliceElement)
              clientRects.push(...webRange.getClientRects())
            }
          }
          // Finally: from the start of the end slice to the end position.
          const startOfEndSliceToken = tokenElementsGroupedBySlice[end.sliceIndex][0]
          webRange.setStart(startOfEndSliceToken, 0)
          if (end.textOffset === FULL_TOKEN_OFFSET) {
            webRange.setEndAfter(endElement.childNodes[0])
          } else {
            webRange.setEnd(endElement.childNodes[0], end.textOffset)
          }
          clientRects.push(...webRange.getClientRects())
        }

        const rectsByYAxis = new Map<number, DOMRect[]>()

        clientRects.forEach((rect) => {
          const rectsAtY = rectsByYAxis.get(rect.y) ?? []
          rectsAtY.push(rect)
          rectsByYAxis.set(rect.y, rectsAtY)
        })

        // Combine all the rectangles on the same y position (or line) into a single rectangle
        const adjustedRects = Array.from(rectsByYAxis.values()).map((rects) => {
          // At this point we have the client rectangles for the token elements and their child text nodes,
          // The token elements have a border which increases their height so get the smaller x and height to "filter" them
          const x = Math.min(...rects.map((rect) => rect.x))
          const minHeight = Math.min(...rects.map((rect) => rect.height))
          const {y} = rects[0]
          const combinedWidth = rects[rects.length - 1].right - rects[0].left

          return {
            x: x - boundingRect.left,
            y: y - boundingRect.top,
            height: minHeight,
            // TODO: Improve handling of Caret highlights
            width: combinedWidth === 0 ? 4 : combinedWidth,
          }
        })

        if (adjustedRects.length > 0) {
          nextStyledSegments.push({segmentGroup: adjustedRects, style: kHighlight.style})
        }
      })
    })

    setStyledSegments(nextStyledSegments)
    // height/width/pixelRatio are not read in the body; they re-trigger measurement when the layout changes.
  }, [kHighlights, pixelRatio, visibleKeys, transcriptRef, height, width])

  // Split the layer into blocks of at most CANVAS_BLOCK_HEIGHT. Every block receives all segments
  // and is positioned at its y offset within the layer.
  const blocks = []
  const CANVAS_BLOCK_HEIGHT = 512
  for (let i = 0; i < height; i += CANVAS_BLOCK_HEIGHT) {
    const remainingPixels = height - i
    blocks.push(
      <AnnotationsLayerBlock
        key={i}
        styledSegments={styledSegments || []}
        height={Math.min(CANVAS_BLOCK_HEIGHT, remainingPixels)}
        pixelRatio={pixelRatio}
        width={width}
        x={0}
        y={i}
      />,
    )
  }
  return blocks
}
