import {useDevicePixelRatio} from '@kensho/tacklebox'
import {MutableRefObject, useEffect, useMemo, useState} from 'react'

import {getNeedsReviewHighlights} from '../highlights/getNeedsReviewHighlights'
import {KHighlight} from '../highlights/types'
import {APITranscript, TranscriptSelectionNode} from '../types/types'
import {indexFromPath, compareTokenPosition} from '../utils/transcriptUtils'

/** Rendered token elements, keyed by slice index and then by token index. */
type TokenElementsBySlice = Record<string, Record<number, Element>>

/** A token's position in transcript order. */
type TokenPosition = {sliceIndex: number; tokenIndex: number}

/** First and last indexed tokens (in DOM order) among the visible slices. */
type TokenRange = {start: TokenPosition; end: TokenPosition}

/**
 * Resolve a selection endpoint to the rendered element it points into.
 *
 * For a `token-space` endpoint, `tokenIndex` refers to the token preceding the space,
 * so the element returned is that token's next sibling.
 *
 * @returns The element, or null when the endpoint is null or its token is not indexed.
 */
function getEleFromTranscriptSelection(
  transcriptSelectionNode: TranscriptSelectionNode | null,
  tokenElementsGroupedBySlice: TokenElementsBySlice,
): Element | null {
  if (!transcriptSelectionNode) return null
  const {type, sliceIndex, tokenIndex} = transcriptSelectionNode
  const tokenElement = tokenElementsGroupedBySlice[sliceIndex]?.[tokenIndex] ?? null
  if (type === 'token-space') {
    // NOTE(review): nextSibling may be a Text node rather than an Element — confirm against the markup.
    return (tokenElement?.nextSibling as Element | null | undefined) ?? null
  }
  return tokenElement
}

/**
 * Index the rendered `[data-type="token"]` elements that belong to visible slices.
 *
 * @param transcriptElement - Root element whose token descendants carry a `data-path` attribute.
 * @param visibleKeys - Comma-joined batch keys; the part before the first `-` is taken as a slice index.
 * @returns The token elements grouped by slice, plus the first/last indexed token in DOM order
 *   (`availableRange` is null when no token of a visible slice is rendered).
 */
function indexVisibleTokens(
  transcriptElement: Element,
  visibleKeys: string,
): {tokenElementsGroupedBySlice: TokenElementsBySlice; availableRange: TokenRange | null} {
  const visibleSlices = new Set(visibleKeys.split(',').map((key) => key.split('-')[0]))
  const tokenElementsGroupedBySlice: TokenElementsBySlice = {}
  let first: TokenPosition | null = null
  let last: TokenPosition | null = null

  for (const token of Array.from(transcriptElement.querySelectorAll('[data-type="token"]'))) {
    const path = token.getAttribute('data-path')
    if (!path) continue
    const {sliceIndex, tokenIndex} = indexFromPath(path)
    // remove offscreen high LOD tokens from consideration
    if (!visibleSlices.has(String(sliceIndex))) continue
    if (!tokenElementsGroupedBySlice[sliceIndex]) {
      tokenElementsGroupedBySlice[sliceIndex] = {}
    }
    tokenElementsGroupedBySlice[sliceIndex][tokenIndex] = token
    // The range bounds follow DOM order of the indexed tokens.
    if (!first) first = {sliceIndex, tokenIndex}
    last = {sliceIndex, tokenIndex}
  }

  const availableRange = first && last ? {start: first, end: last} : null
  return {tokenElementsGroupedBySlice, availableRange}
}

/**
 * First and last rendered token elements of a slice, by lowest and highest token index.
 *
 * @returns null when the slice is missing or has no rendered tokens.
 */
function getSliceBoundaryTokens(
  slice: Record<number, Element> | undefined,
): {first: Element; last: Element} | null {
  if (!slice) return null
  let minIndex = Infinity
  let maxIndex = -Infinity
  for (const key of Object.keys(slice)) {
    const index = Number(key)
    if (index < minIndex) minIndex = index
    if (index > maxIndex) maxIndex = index
  }
  if (minIndex === Infinity) return null
  return {first: slice[minIndex], last: slice[maxIndex]}
}

/**
 * Collect the client rects covering the selection from `start` to `end`.
 *
 * @param singleSlice - When true, a single DOM range from start to end is measured.
 * @returns The rects, or null when either endpoint does not resolve to a rendered element
 *   with a child node to place the range boundary in.
 */
function getSelectionClientRects(
  start: TranscriptSelectionNode,
  end: TranscriptSelectionNode,
  singleSlice: boolean,
  tokenElementsGroupedBySlice: TokenElementsBySlice,
): DOMRect[] | null {
  const startElement = getEleFromTranscriptSelection(start, tokenElementsGroupedBySlice)
  const endElement = getEleFromTranscriptSelection(end, tokenElementsGroupedBySlice)
  // Offsets are applied to each element's first child; Range.setStart/setEnd throw without one.
  const startText = startElement?.firstChild
  const endText = endElement?.firstChild
  if (!startText || !endText) return null

  const webRange = document.createRange()
  if (singleSlice) {
    webRange.setStart(startText, start.textOffset)
    webRange.setEnd(endText, end.textOffset)
    return Array.from(webRange.getClientRects())
  }

  // A single range spanning several slices would also select the slice headers between them,
  // so the rects are collected slice by slice.
  const clientRects: DOMRect[] = []

  // Start slice: from the start offset through its last token (or the element right after it).
  const startSlice = getSliceBoundaryTokens(tokenElementsGroupedBySlice[start.sliceIndex])
  if (startSlice) {
    webRange.setStart(startText, start.textOffset)
    webRange.setEndAfter(startSlice.last.nextElementSibling ?? startSlice.last)
    clientRects.push(...Array.from(webRange.getClientRects()))
  }

  // Every slice strictly between the start and end slices, in full.
  for (let i = start.sliceIndex + 1; i < end.sliceIndex; i += 1) {
    const slice = getSliceBoundaryTokens(tokenElementsGroupedBySlice[i])
    if (!slice) continue
    webRange.setStartBefore(slice.first)
    webRange.setEndAfter(slice.last.nextElementSibling ?? slice.last)
    clientRects.push(...Array.from(webRange.getClientRects()))
  }

  // End slice: from its first rendered token up to the end offset.
  const endSlice = getSliceBoundaryTokens(tokenElementsGroupedBySlice[end.sliceIndex])
  if (endSlice) {
    webRange.setStart(endSlice.first, 0)
    webRange.setEnd(endText, end.textOffset)
    clientRects.push(...Array.from(webRange.getClientRects()))
  }

  return clientRects
}

/**
 * Merge rects that share the same `y` (i.e. sit on the same rendered line) into one segment.
 * Positions are expressed relative to `boundingRect`.
 */
function mergeRectsByLine(clientRects: DOMRect[], boundingRect: DOMRect) {
  const rectsByY = new Map<number, DOMRect[]>()
  clientRects.forEach((rect) => {
    const line = rectsByY.get(rect.y)
    if (line) {
      line.push(rect)
    } else {
      rectsByY.set(rect.y, [rect])
    }
  })

  return Array.from(rectsByY, ([y, rects]) => {
    // The rects include token element boxes and their child text nodes. Token elements have a
    // border that makes them taller, so the smallest height is used to "filter" them out.
    const x = Math.min(...rects.map((rect) => rect.x))
    const right = Math.max(...rects.map((rect) => rect.right))
    const minHeight = Math.min(...rects.map((rect) => rect.height))
    const combinedWidth = right - x

    return {
      x: x - boundingRect.left,
      y: y - boundingRect.top,
      height: minHeight,
      // TODO: Improve handling of Caret highlights (a zero-width rect is widened to 4px).
      width: combinedWidth === 0 ? 4 : combinedWidth,
    }
  })
}

/**
 * Compute the on-screen segments of the transcript's "needs review" highlights.
 *
 * Only tokens of visible slices are measured. Ranges are clamped to the rendered tokens;
 * ranges entirely outside them, or whose endpoints cannot be resolved, are returned without
 * `segments`.
 *
 * @param transcript - Source of the highlights; when undefined, no highlights are returned.
 * @param visibleBatches - Batch keys whose slices are currently rendered.
 * @param transcriptRef - Ref to the transcript root element; segment positions are relative to it.
 * @param height - Not read directly; included only so layout changes trigger a re-measure.
 * @param width - Not read directly; included only so layout changes trigger a re-measure.
 * @param activeHighlightId - Forwarded to `getNeedsReviewHighlights`.
 * @returns The highlights with per-line `segments` attached to each measurable range.
 */
function useTranscriptHighlights(
  transcript: APITranscript | undefined,
  visibleBatches: Record<string, boolean>,
  transcriptRef: MutableRefObject<HTMLDivElement | null>,
  height: number,
  width: number,
  activeHighlightId: string | undefined,
): KHighlight[] {
  const [kHighlights, setKHighlights] = useState<KHighlight[]>([])
  const pixelRatio = useDevicePixelRatio()
  // Serialized so the effect re-runs only when the set of batch keys changes.
  // NOTE(review): keys are used regardless of their boolean value — confirm callers never store `false`.
  const visibleKeys = useMemo(() => Object.keys(visibleBatches).toString(), [visibleBatches])

  useEffect(() => {
    if (!transcript) {
      setKHighlights([])
      return
    }
    const transcriptElement = transcriptRef.current
    if (!transcriptElement) return

    const needsReviewHighlights = getNeedsReviewHighlights(transcript, activeHighlightId)
    const {tokenElementsGroupedBySlice, availableRange} = indexVisibleTokens(
      transcriptElement,
      visibleKeys,
    )
    const boundingRect = transcriptElement.getBoundingClientRect()

    const withSegments: KHighlight[] = needsReviewHighlights.map((kHighlight) => ({
      ...kHighlight,
      ranges: kHighlight.ranges.map((range) => {
        if (!range.start || !range.end || !availableRange) return range
        const start = {...range.start}
        const end = {...range.end}
        // Ranges lying entirely outside the rendered tokens cannot be measured.
        if (
          compareTokenPosition(end, availableRange.start) === 'before' ||
          compareTokenPosition(start, availableRange.end) === 'after'
        ) {
          return range
        }
        if (compareTokenPosition(start, availableRange.start) === 'before') {
          start.sliceIndex = availableRange.start.sliceIndex
          start.tokenIndex = availableRange.start.tokenIndex
          // The original offset belonged to an offscreen token and may exceed this token's
          // length (Range.setStart would throw); start at the beginning of the first visible one.
          start.textOffset = 0
        }
        if (compareTokenPosition(end, availableRange.end) === 'after') {
          end.sliceIndex = availableRange.end.sliceIndex
          end.tokenIndex = availableRange.end.tokenIndex
          end.textOffset =
            getEleFromTranscriptSelection(end, tokenElementsGroupedBySlice)?.textContent?.length ??
            0
        }

        const clientRects = getSelectionClientRects(
          start,
          end,
          range.type === 'Caret' || start.sliceIndex === end.sliceIndex,
          tokenElementsGroupedBySlice,
        )
        if (!clientRects) return range
        return {
          ...range,
          segments: mergeRectsByLine(clientRects, boundingRect),
        }
      }),
    }))
    setKHighlights(withSegments)
    // pixelRatio, height and width are not read here; they force a re-measure when layout changes.
  }, [transcript, pixelRatio, visibleKeys, transcriptRef, height, width, activeHighlightId])

  return kHighlights
}

export default useTranscriptHighlights
