import { FormControl, InputLabel, Select, useTheme } from '@mui/material'
import { RichTreeView } from '@mui/x-tree-view'
import { useTreeViewApiRef } from '@mui/x-tree-view/hooks'
import { TreeViewBaseItem } from '@mui/x-tree-view/models'
import { CSSProperties, FC, SyntheticEvent, startTransition, useCallback, useMemo, useRef } from 'react'
import { ThemeMode } from 'types/config'
import { Period, parsePeriodDate } from 'types/periods'

/** Props of {@link MultiPeriodSelector}. */
interface MultiPeriodSelectorProps {
  /** Periods offered for selection; nested `children` become nested tree items. Periods without an id are not shown. */
  periods: Period[]
  /** Ids of the selected periods. The selector is controlled and always displays this list. */
  selectedPeriods: string[]
  /** Receives the next list of selected ids, including descendants toggled together with their parent. */
  handleChange: (selectedPeriods: string[]) => void

  /** Show only the root-level periods; their children are not rendered. Default `false`. */
  rootPeriodOnly?: boolean
  /** Ordering of the root-level periods by end date. Default `'desc'`. */
  sortOrder?: 'asc' | 'desc'
  /**
   * `'pastOnly'` keeps periods whose start date is on or before now, `'futureOnly'` keeps periods that
   * start after now, `'standard'` keeps everything. Periods with unparseable dates are always kept.
   * Default `'standard'`.
   */
  timeFilter?: 'pastOnly' | 'futureOnly' | 'standard'
  /** Hide closed periods that are their own root (`rootPeriodId === id`) and show other closed periods disabled. Default `false`. */
  disableClosed?: boolean
  /** Once a selection exists, disable the items that belong to a different root period. Default `false`. */
  disableMultiRootSelection?: boolean
  /** Disable the whole select. Default `false`. */
  disabled?: boolean

  /** Label text; also used for the collapsed summary `"<label> (<count>)"`. Default `'Filter Periods'`. */
  label?: string
  /** Text color of the label (the label is always rendered shrunk). */
  labelColor?: string
  /** Background of the label (the label is always rendered shrunk). */
  labelBackgroundColor?: string
  /** Class name of the wrapper `<div>`. */
  className?: string
  /** Inline style of the wrapper `<div>`; `marginBottom`, `paddingBottom` and `width` are always overridden. */
  style?: CSSProperties
  /** Width of the wrapper `<div>`. Default `'100%'`. */
  width?: string | number

  // NOTE(review): the props below are accepted but not currently read by MultiPeriodSelector.
  showLabelAbove?: boolean
  space?: number
  variant?: 'standard' | 'outlined'
  shrinkLabel?: boolean
  placeholder?: string
}

/**
 * Multi-select dropdown that renders `periods` as a checkbox tree (MUI `RichTreeView`) inside an MUI `Select`.
 *
 * - Periods are filtered by `timeFilter` / `disableClosed`, root periods are sorted by end date according to
 *   `sortOrder`, and children are rendered unless `rootPeriodOnly` is set.
 * - Toggling an item also toggles all of its enabled descendants; the resulting id list is passed to
 *   `handleChange`. The component is controlled: the tree always shows `selectedPeriods`.
 *
 * NOTE(review): `showLabelAbove`, `space`, `variant`, `shrinkLabel` and `placeholder` are accepted but not read here.
 */
const MultiPeriodSelector: FC<MultiPeriodSelectorProps> = ({
  periods,
  selectedPeriods,
  className,
  style,
  width = '100%',
  handleChange,
  label = 'Filter Periods',
  rootPeriodOnly = false,
  sortOrder = 'desc',
  timeFilter = 'standard',
  disabled = false,
  disableMultiRootSelection = false,
  showLabelAbove = false,
  space = 0,
  variant = 'standard',
  shrinkLabel = false,
  placeholder = 'Selected Periods',
  labelColor = undefined,
  labelBackgroundColor = undefined,
  disableClosed = false,
}) => {
  const theme = useTheme()
  const apiRef = useTreeViewApiRef()
  // Items toggled by the current user action (itemId -> now selected?). Written by onItemSelectionToggle and
  // consumed/cleared by onSelectedItemsChange; this relies on RichTreeView reporting the toggles first, as in
  // MUI's parent/children selection example.
  const toggledItemRef = useRef<{ [itemId: string]: boolean }>({})

  // Recursively drops periods rejected by `disableClosed` or `timeFilter`.
  const filteredPeriods = useMemo(() => {
    const now = new Date()

    // Periods with unparseable dates are always kept.
    const matchesTimeFilter = (period: Period): boolean => {
      const startDate = parsePeriodDate(period.startDate)
      const endDate = parsePeriodDate(period.endDate)
      if (!startDate || !endDate) return true
      switch (timeFilter) {
        case 'pastOnly':
          return startDate <= now
        case 'futureOnly':
          return startDate > now
        default:
          return true
      }
    }

    const filterPeriod = (period: Period): Period | null => {
      if (disableClosed && period.isClosed && period.rootPeriodId !== null && period.rootPeriodId === period.id) {
        return null
      }
      if (!matchesTimeFilter(period)) return null

      // Children are filtered even when the parent itself has unparseable dates.
      if (period.children && !rootPeriodOnly) {
        const filteredChildren = period.children.map(filterPeriod).filter((child): child is Period => child !== null)
        return { ...period, children: filteredChildren }
      }
      return period
    }

    return periods.map(filterPeriod).filter((period): period is Period => period !== null)
  }, [periods, timeFilter, rootPeriodOnly, disableClosed])

  // Root periods ordered by end date; periods whose end date cannot be parsed are kept together at the end.
  const sortedPeriods = useMemo(() => {
    // Same parser as the filter above, so both agree on which dates are valid.
    const endTime = (period: Period): number => parsePeriodDate(period.endDate)?.getTime() ?? Number.NaN
    return [...filteredPeriods].sort((a, b) => {
      const timeA = endTime(a)
      const timeB = endTime(b)
      const aInvalid = Number.isNaN(timeA)
      const bInvalid = Number.isNaN(timeB)
      if (aInvalid || bInvalid) {
        // A NaN comparator result would make Array#sort's order unspecified.
        return Number(aInvalid) - Number(bInvalid)
      }
      return sortOrder === 'asc' ? timeA - timeB : timeB - timeA
    })
  }, [filteredPeriods, sortOrder])

  // Kept referentially stable so RichTreeView only re-processes its items when the data actually changes.
  const treeItems = useMemo(() => {
    const toTreeItems = (nodes: Period[] = []): TreeViewBaseItem[] =>
      nodes
        .filter((node) => node.id)
        .map((node) => ({
          id: node.id!,
          label: node.label,
          children: rootPeriodOnly ? [] : toTreeItems(node.children),
        }))
    return toTreeItems(sortedPeriods)
  }, [sortedPeriods, rootPeriodOnly])

  // period id -> period anywhere in the sorted tree (first match in depth-first pre-order wins).
  const periodById = useMemo(() => {
    const byId = new Map<string, Period>()
    const visit = (nodes: Period[]) => {
      for (const period of nodes) {
        if (period.id && !byId.has(period.id)) byId.set(period.id, period)
        if (period.children) visit(period.children)
      }
    }
    visit(sortedPeriods)
    return byId
  }, [sortedPeriods])

  // tree item id -> id of the top-level item it lives under (top-level items map to themselves).
  const rootIdByItemId = useMemo(() => {
    const byId = new Map<string, string>()
    const visit = (items: TreeViewBaseItem[], rootId: string) => {
      for (const item of items) {
        if (!byId.has(item.id)) byId.set(item.id, rootId)
        if (item.children) visit(item.children, rootId)
      }
    }
    for (const root of treeItems) visit([root], root.id)
    return byId
  }, [treeItems])

  // Root that the current selection is restricted to under `disableMultiRootSelection`: the root of the first
  // selected id present in the tree. Selected ids that are filtered out of the tree are skipped so they can't
  // lock every item; with no resolvable selection there is no restriction.
  const anchorRootId = useMemo(() => {
    if (!disableMultiRootSelection) return undefined
    for (const id of selectedPeriods) {
      const rootId = rootIdByItemId.get(id)
      if (rootId !== undefined) return rootId
    }
    return undefined
  }, [disableMultiRootSelection, selectedPeriods, rootIdByItemId])

  const isItemDisabled = useCallback(
    (item: TreeViewBaseItem): boolean => {
      if (disableClosed && periodById.get(item.id)?.isClosed) return true
      if (!disableMultiRootSelection || anchorRootId === undefined) return false
      return rootIdByItemId.get(item.id) !== anchorRootId
    },
    [disableClosed, disableMultiRootSelection, periodById, rootIdByItemId, anchorRootId]
  )

  // Ids of every enabled descendant of `item`; a disabled child excludes its whole subtree.
  const getItemDescendantsIds = (item: TreeViewBaseItem): string[] => {
    const ids: string[] = []
    item.children?.forEach((child) => {
      if (!isItemDisabled(child)) {
        ids.push(child.id, ...getItemDescendantsIds(child))
      }
    })
    return ids
  }

  const handleItemSelectionToggle = (_event: SyntheticEvent, itemId: string, isSelected: boolean) => {
    // Plain ref write: no state update is involved, so no transition is needed.
    toggledItemRef.current[itemId] = isSelected
  }

  const handleSelectedItemsChange = (_event: SyntheticEvent, newSelectedItems: string[]) => {
    // Take and clear the pending toggles first so a throwing `handleChange` can't leak them into the next action.
    const toggled = toggledItemRef.current
    toggledItemRef.current = {}

    const api = apiRef.current
    const itemsToSelect: string[] = []
    const itemsToUnselect = new Set<string>()
    for (const [itemId, isSelected] of Object.entries(toggled)) {
      const item = api?.getItem(itemId)
      if (!item) continue
      const descendantIds = getItemDescendantsIds(item)
      if (isSelected) {
        itemsToSelect.push(...descendantIds)
      } else {
        descendantIds.forEach((descendantId) => itemsToUnselect.add(descendantId))
      }
    }

    const nextSelectedItems = Array.from(new Set([...newSelectedItems, ...itemsToSelect])).filter(
      (itemId) => !itemsToUnselect.has(itemId)
    )

    startTransition(() => {
      handleChange(nextSelectedItems)
    })
  }

  const hoverColor =
    theme.palette.mode === ThemeMode.DARK ? theme.palette.primary.darker : theme.palette.primary.lighter

  return (
    <div className={className} style={{ ...style, marginBottom: 0, paddingBottom: 0, width }}>
      <FormControl variant="standard" sx={{ width: '100%' }}>
        <InputLabel
          shrink
          variant="outlined"
          sx={{
            '&.MuiInputLabel-shrink': {
              background: labelBackgroundColor || undefined,
              color: labelColor || undefined,
            },
          }}
        >
          {label}
        </InputLabel>
        <Select
          value={selectedPeriods}
          displayEmpty
          multiple
          renderValue={() => (selectedPeriods.length > 0 ? `${label} (${selectedPeriods.length})` : '')}
          variant="outlined"
          style={{ width: '100%', borderRadius: '21px' }}
          disabled={disabled}
        >
          <RichTreeView
            multiSelect
            checkboxSelection
            style={{ width: '100%' }}
            apiRef={apiRef}
            items={treeItems}
            isItemDisabled={isItemDisabled}
            selectedItems={selectedPeriods}
            onSelectedItemsChange={handleSelectedItemsChange}
            onItemSelectionToggle={handleItemSelectionToggle}
            sx={{
              '& .MuiTreeItem-content:hover': {
                backgroundColor: hoverColor,
              },
              '& .MuiTreeItem-content.Mui-selected:hover': {
                backgroundColor: theme.palette.action.selected,
              },
            }}
          />
        </Select>
      </FormControl>
    </div>
  )
}

// Expose the component as this module's default export.
export { MultiPeriodSelector as default }
