import {useCallback, useEffect, useInsertionEffect, useRef, useState} from 'react'
import {z} from 'zod'

import camelCaseDeep from '../utils/camelCaseDeep'
import snakeCaseDeep from '../utils/snakeCaseDeep'

/**
 * Returns a function whose identity stays the same for the lifetime of the component but which
 * always forwards to the most recently supplied `callback`.
 *
 * @param callback The function to forward to; may be a new closure on every render.
 * @returns A memoized wrapper that passes its arguments to the latest `callback` and returns
 * whatever that call returns.
 */
function useEventEffect<R, A extends unknown[]>(callback: (...args: A) => R): (...args: A) => R {
  const latestCallback = useRef(callback)

  // Re-point the ref whenever a different callback is supplied.
  useInsertionEffect(() => {
    latestCallback.current = callback
  }, [callback])

  // Empty dependency list: the wrapper is created once and only reads through the ref.
  return useCallback((...args: A): R => {
    return latestCallback.current(...args)
  }, [])
}

/**
 * Options accepted by `useTypedWebSocket`.
 *
 * `C` is the type of messages the client sends; `S` is the type of messages the server sends.
 */
interface UseWebSocketOptions<C, S> {
  // --- connection ---

  /** Address passed to the `WebSocket` constructor. */
  url: string
  /** When false, no socket is created and `sendMessage` does nothing. Defaults to true. */
  shouldConnect?: boolean

  // --- validation ---

  /** Schema every outgoing message is parsed with before it is serialized. */
  clientMessageSchema: z.ZodType<C>
  /** Schema every incoming message is parsed with before it reaches `onMessage`. */
  serverMessageSchema: z.ZodType<S>

  // --- callbacks ---

  /**
   * Invoked on the socket's `open` event, before any queued messages are flushed. Receives the
   * hook's `sendMessage` function.
   */
  onOpen?: (sendMessage: (message: C) => void) => void
  /** Invoked with each server message after it has been parsed with `serverMessageSchema`. */
  onMessage?: (message: S) => void
  /** Invoked with the `CloseEvent` when the socket's `close` event fires. */
  onClose?: (event: CloseEvent) => void
}

/** The value returned by `useTypedWebSocket`; `C` is the type of messages the client sends. */
interface UseWebSocketResult<C> {
  /**
   * Parses `message` with the client schema, serializes it and sends it over the socket. If the
   * socket is not currently open, the serialized message is queued and sent when the socket
   * opens. Does nothing while `shouldConnect` is false.
   */
  sendMessage: (message: C) => void
}

/**
 * Opens a WebSocket to `url` over which client and server trade stringified JSON messages, each
 * checked against the supplied schema.
 *
 * Outgoing messages are parsed with `clientMessageSchema`, passed through `snakeCaseDeep` and
 * stringified. Incoming messages are JSON-parsed, passed through `camelCaseDeep` and parsed with
 * `serverMessageSchema` before being handed to `onMessage`. Messages sent while the socket is
 * not open are buffered and written out on the next `open` event; the buffer is discarded
 * whenever `shouldConnect` is false.
 *
 * The socket is created when `shouldConnect` is true and re-created when `url` changes; it is
 * closed when the effect is torn down.
 *
 * @returns An object exposing `sendMessage`.
 */
export default function useTypedWebSocket<C, S>({
  url,
  shouldConnect = true,
  clientMessageSchema,
  serverMessageSchema,
  onOpen,
  onMessage,
  onClose,
}: UseWebSocketOptions<C, S>): UseWebSocketResult<C> {
  const socketRef = useRef<WebSocket | null>(null)
  const [pendingMessages, setPendingMessages] = useState<string[]>([])

  // Nothing may stay buffered while no connection is wanted.
  const hasPending = pendingMessages.length > 0
  if (hasPending && !shouldConnect) {
    setPendingMessages([])
  }

  const sendMessage = useCallback(
    (message: C): void => {
      if (!shouldConnect) {
        return
      }

      const socket = socketRef.current
      const validated = clientMessageSchema.parse(message)
      const payload = JSON.stringify(snakeCaseDeep(validated))

      if (socket !== null && socket.readyState === WebSocket.OPEN) {
        socket.send(payload)
        return
      }

      // Socket missing or not open: hold on to the payload until the next `open` event.
      setPendingMessages((queued) => queued.concat(payload))
    },
    [clientMessageSchema, shouldConnect],
  )

  const handleOpen = useEventEffect((socket: WebSocket): void => {
    // The consumer's callback runs first, so anything it sends precedes the buffered messages.
    onOpen?.(sendMessage)

    if (pendingMessages.length === 0) {
      return
    }
    pendingMessages.forEach((payload) => {
      socket.send(payload)
    })
    setPendingMessages([])
  })

  const handleMessage = useEventEffect((event: MessageEvent): void => {
    const decoded = camelCaseDeep(JSON.parse(event.data))
    // Validate unconditionally, even when no `onMessage` callback was supplied.
    const parsed = serverMessageSchema.parse(decoded)
    onMessage?.(parsed)
  })

  const handleClose = useEventEffect((event: CloseEvent): void => {
    if (onClose) {
      onClose(event)
    }
  })

  useEffect(() => {
    if (!shouldConnect) {
      return undefined
    }

    const socket = new WebSocket(url)
    socketRef.current = socket

    // The `open` event carries no socket reference of its own, so bind this one explicitly.
    const onSocketOpen = (): void => {
      handleOpen(socket)
    }

    socket.addEventListener('open', onSocketOpen)
    socket.addEventListener('message', handleMessage)
    socket.addEventListener('close', handleClose)

    return () => {
      // Detach every listener before closing the socket.
      socket.removeEventListener('open', onSocketOpen)
      socket.removeEventListener('message', handleMessage)
      socket.removeEventListener('close', handleClose)
      socket.close()
      socketRef.current = null
    }
  }, [handleClose, handleMessage, handleOpen, shouldConnect, url])

  return {sendMessage}
}
