import {useCallback, useEffect, useInsertionEffect, useRef, useState} from 'react'
import {z} from 'zod'

import camelCaseDeep from '../utils/camelCaseDeep'
import snakeCaseDeep from '../utils/snakeCaseDeep'

/**
 * Lifecycle state reported by `useTypedWebSocket`:
 * - `'idle'`: no socket is active. Either the hook was told not to connect, or it is waiting
 *   out the backoff delay before a reconnection attempt.
 * - `'connecting'`: a socket is being opened.
 * - `'open'`: the socket fired its `open` event.
 * - `'closed'`: the socket closed and the hook gave up reconnecting (either `onClose`
 *   declined to reconnect or the retry budget ran out).
 */
type WebSocketConnectionStatus = 'idle' | 'connecting' | 'open' | 'closed'

/**
 * Computes how long to wait before the next reconnection attempt.
 *
 * The delay starts at one second for attempt 0, doubles with every further attempt, and is
 * capped at thirty seconds.
 *
 * @param retryAttempt - The reconnection attempt number.
 * @returns The delay in milliseconds.
 */
function exponentialBackoff(retryAttempt: number): number {
  const baseDelayMs = 1000
  const maxDelayMs = 30000
  const uncappedDelayMs = baseDelayMs * 2 ** retryAttempt
  return uncappedDelayMs > maxDelayMs ? maxDelayMs : uncappedDelayMs
}

/**
 * Wraps `callback` in a function whose identity never changes but which always forwards to
 * the callback from the most recently committed render.
 *
 * The latest callback is copied into a ref during the insertion-effect phase. The returned
 * wrapper can therefore be listed in effect dependency arrays without making those effects
 * re-run. Because the ref is only updated after a render commits, calling the wrapper during
 * a re-render would still reach the previous render's callback. It is meant for event
 * handlers, not for use during render.
 */
function useEventEffect<R, A extends unknown[]>(callback: (...args: A) => R): (...args: A) => R {
  const latestCallbackRef = useRef(callback)

  useInsertionEffect(() => {
    latestCallbackRef.current = callback
  }, [callback])

  const stableWrapper = useCallback(
    (...args: A): R => latestCallbackRef.current(...args),
    [],
  )
  return stableWrapper
}

/**
 * Options accepted by `useTypedWebSocket`. `C` is the type of client-to-server messages and
 * `S` the type of server-to-client messages.
 */
interface UseWebSocketOptions<C, S> {
  /**
   * The schema for validating messages sent from the client. Messages that pass are converted
   * with `snakeCaseDeep` and JSON-stringified before being sent.
   */
  clientMessageSchema: z.ZodType<C>
  /**
   * The schema for validating messages received from the server. Incoming data is JSON-parsed
   * and converted with `camelCaseDeep` before validation.
   */
  serverMessageSchema: z.ZodType<S>
  /**
   * The maximum number of consecutive reconnection attempts made when `onClose` asks to
   * reconnect. Defaults to 10.
   */
  maxRetries?: number
  /**
   * Called when a message does not pass schema validation. `message` is the outgoing value for
   * client messages and the raw `event.data` for server messages. `isClientMessage` tells the
   * two cases apart.
   */
  onValidationError?: (error: z.ZodError, message: unknown, isClientMessage: boolean) => void
  /**
   * Called when the WebSocket connection is opened, with the hook's `sendMessage`. Messages sent
   * from here go out before any messages that were queued while connecting.
   */
  onOpen?: (sendMessage: (message: C) => void) => void
  /**
   * Called when the WebSocket connection is closed. Return `true` to reconnect after an
   * exponential backoff (subject to `maxRetries`). Otherwise the status becomes `'closed'`.
   */
  onClose?: (event: CloseEvent) => boolean
  /**
   * Called with each validated message from the server. Return `true` to request a
   * reconnection.
   */
  onMessage?: (message: S) => boolean
  /** The URL to connect to. Changing it replaces the active socket. */
  url: string
  /**
   * Whether to establish the WebSocket connection. Defaults to `true`. Setting it to `false`
   * tears down any connection and discards queued messages.
   */
  shouldConnect?: boolean
}

/** Value returned by `useTypedWebSocket`. */
interface UseWebSocketResult<C> {
  /**
   * Validates and sends a client message. Validation failures go to `onValidationError`. If the
   * socket is not open, the serialized message is queued until it opens. Calls are ignored while
   * `shouldConnect` is `false`.
   */
  sendMessage: (message: C) => void
  /** The current connection status. */
  status: WebSocketConnectionStatus
}

/**
 * Creates a WebSocket connection in which the server and client exchange stringified JSON
 * messages that are validated against the given schemas.
 *
 * - Outgoing messages are validated with `clientMessageSchema`, converted with `snakeCaseDeep`
 *   and stringified. If the socket is not open, they are queued and flushed once it opens.
 *   The flush happens after `onOpen` has run, so anything `onOpen` sends goes out first.
 * - Incoming messages are JSON-parsed, converted with `camelCaseDeep` and validated with
 *   `serverMessageSchema` before being passed to `onMessage`.
 * - Schema failures are reported through `onValidationError`. Any other error is rethrown
 *   rather than silently dropped, e.g. a frame that is not valid JSON or an exception thrown
 *   by `onMessage`.
 * - When `onClose` returns `true`, the hook waits `exponentialBackoff(attempt)` ms and
 *   reconnects, up to `maxRetries` consecutive attempts. A successful open resets the count.
 * - When `onMessage` returns `true`, the current socket is closed and a new one is opened
 *   immediately.
 * - Setting `shouldConnect` to `false` closes the socket, cancels any pending reconnect and
 *   discards queued messages.
 */
export default function useTypedWebSocket<C, S>({
  clientMessageSchema,
  maxRetries = 10,
  onClose,
  onMessage,
  onOpen,
  onValidationError,
  serverMessageSchema,
  shouldConnect = true,
  url,
}: UseWebSocketOptions<C, S>): UseWebSocketResult<C> {
  const webSocketRef = useRef<WebSocket | null>(null)
  // Serialized messages sent while no socket was open. This is a ref rather than state so that
  // `handleOpen` always sees every queued message, including one whose state update would not
  // yet have been committed when the `open` event fires.
  const messageQueueRef = useRef<string[]>([])
  const [status, setStatus] = useState<WebSocketConnectionStatus>('idle')
  const [connectionAttempt, setConnectionAttempt] = useState(0)
  // Incremented to force the socket effect to replace the current socket with a new one.
  const [connectionKey, setConnectionKey] = useState(0)

  // Reset to a pristine idle state when asked to disconnect. The attempt count is reset even
  // if already idle, so the next connect starts immediately instead of after a stale backoff.
  if (!shouldConnect && (status !== 'idle' || connectionAttempt !== 0)) {
    setStatus('idle')
    setConnectionAttempt(0)
  }

  // Start connecting immediately on a fresh connect (attempt 0 has no backoff).
  if (shouldConnect && status === 'idle' && connectionAttempt === 0) {
    setStatus('connecting')
  }

  // Discard the message queue once disconnected. `sendMessage` refuses new messages meanwhile.
  useEffect(() => {
    if (!shouldConnect) messageQueueRef.current = []
  }, [shouldConnect])

  const reportValidationError = useEventEffect(
    (error: z.ZodError, message: unknown, isClientMessage: boolean): void => {
      onValidationError?.(error, message, isClientMessage)
    },
  )

  const sendMessage = useCallback(
    (message: C): void => {
      if (!shouldConnect) return
      let messageString: string
      try {
        messageString = JSON.stringify(snakeCaseDeep(clientMessageSchema.parse(message)))
      } catch (error) {
        // Only schema failures are expected here; anything else is a programming error.
        if (!(error instanceof z.ZodError)) throw error
        reportValidationError(error, message, true)
        return
      }
      const ws = webSocketRef.current
      if (ws && ws.readyState === WebSocket.OPEN) ws.send(messageString)
      else messageQueueRef.current.push(messageString)
    },
    [clientMessageSchema, reportValidationError, shouldConnect],
  )

  const handleOpen = useEventEffect((ws: WebSocket): void => {
    setStatus('open')
    setConnectionAttempt(0)
    onOpen?.(sendMessage)
    // Flush messages queued while connecting. Messages sent from `onOpen` already went out first.
    const queuedMessages = messageQueueRef.current
    messageQueueRef.current = []
    for (const queuedMessage of queuedMessages) ws.send(queuedMessage)
  })

  const handleMessage = useEventEffect((event: MessageEvent): void => {
    let message: S
    try {
      message = serverMessageSchema.parse(camelCaseDeep(JSON.parse(event.data)))
    } catch (error) {
      if (!(error instanceof z.ZodError)) throw error
      reportValidationError(error, event.data, false)
      return
    }
    // Invoked outside the `try` so exceptions from the consumer's callback are not swallowed.
    if (onMessage?.(message)) {
      // Replace the socket right away. Going through 'idle' would not work here: with
      // connectionAttempt === 0, the render-phase branch above flips 'idle' straight back to
      // 'connecting', so the socket effect would never re-run.
      setStatus('connecting')
      setConnectionKey((key) => key + 1)
    }
  })

  const handleClose = useEventEffect((event: CloseEvent): void => {
    const shouldReconnect = onClose ? onClose(event) : false
    if (shouldReconnect && connectionAttempt < maxRetries) {
      setStatus('idle')
      setConnectionAttempt((prev) => prev + 1)
    } else {
      setStatus('closed')
      setConnectionAttempt(0)
    }
  })

  // Wait for a backoff period before attempting to reconnect, but only while a connection is
  // actually wanted.
  const isIdle = status === 'idle'
  useEffect(() => {
    if (!isIdle || !shouldConnect) return undefined
    const timeout = exponentialBackoff(connectionAttempt)
    const timeoutId = window.setTimeout(() => setStatus('connecting'), timeout)
    return () => window.clearTimeout(timeoutId)
  }, [connectionAttempt, isIdle, shouldConnect])

  const isActive = status === 'connecting' || status === 'open'
  useEffect(() => {
    if (!isActive) return undefined
    const ws = new WebSocket(url)
    webSocketRef.current = ws

    const handleOpenImpl = (): void => handleOpen(ws)
    ws.addEventListener('open', handleOpenImpl)
    ws.addEventListener('message', handleMessage)
    ws.addEventListener('close', handleClose)

    return () => {
      // Listeners are removed before closing, so a deliberate teardown never reaches onClose.
      ws.removeEventListener('open', handleOpenImpl)
      ws.removeEventListener('message', handleMessage)
      ws.removeEventListener('close', handleClose)
      ws.close()
      webSocketRef.current = null
    }
    // `connectionKey` is intentionally a dependency: bumping it replaces the socket.
  }, [connectionKey, handleClose, handleMessage, handleOpen, isActive, url])

  return {sendMessage, status}
}
