import React, { ElementRef, FC, ForwardRefRenderFunction, JSXElementConstructor } from 'react'
import type { BetterVariantProps, CSS, VariantProps } from './stitches.config'
import type { O } from 'ts-toolbelt'

/**
 * Combines object types `T` and `U` using ts-toolbelt's `O.Merge` at `'flat'` depth
 * (nested object properties are not merged recursively).
 *
 * NOTE(review): the name suggests `U` takes precedence on overlapping keys; confirm that
 * `O.Merge` actually resolves overlapping required keys in favour of `U` for the
 * ts-toolbelt version in use, otherwise a different helper may have been intended.
 */
export type Override<T extends object, U extends object> = O.Merge<T, U, 'flat'>
/** Union of all intrinsic JSX element tag names (e.g. `'div'`, `'span'`). */
export type IntrinsicElementsKeys = keyof JSX.IntrinsicElements

/**
 * Static props that `withId` attaches to a component: `id`, `displayName` and `className`
 * are all set to the given name, and `toString()` returns `.${name}`.
 */
type ExtraStaticProps = {
  id: string
  displayName: string
  toString: () => string
  className: string
}

/**
 * Types accepted by the `as` prop declared in the default `CommonProps` of `withStyled` /
 * `withStyledRef`: an intrinsic tag name or a component type.
 *
 * NOTE(review): `React.ElementType` likely already covers the other two members; the union
 * is kept as written.
 */
type AsProps = IntrinsicElementsKeys | React.ComponentType | React.ElementType
/**
 * `React.FC` that may additionally carry the static identification props (`id`,
 * `displayName`, `toString`, `className`) that helpers in this file attach to components.
 * All of them are optional here.
 *
 * NOTE(review): not referenced elsewhere in this chunk; presumably used by consumers.
 */
export interface ExtendedFC<P> extends React.FC<P> {
  id?: string
  displayName?: string
  // Used for stitches: the helpers here make it return a `.name` class-selector string
  toString?: () => string
  className?: string
}

/**
 * Attaches static identification props to `_component` and returns it (the same object,
 * mutated in place).
 *
 * `id`, `displayName` and `className` are set to `displayName`; `toString()` returns the
 * class-selector form `.${displayName}` (e.g. for use as a stitches selector).
 *
 * @param displayName - Name to assign.
 * @param _component - Object or function to tag.
 * @returns `_component`, typed as `T & ExtraStaticProps`.
 */
export function withId<T>(displayName: string, _component: T) {
  const selector = `.${displayName}`
  const tagged = _component as T & ExtraStaticProps

  // Assignment order is kept stable so enumeration order of the new keys is predictable.
  tagged.id = displayName
  tagged.displayName = displayName
  tagged.toString = () => selector
  tagged.className = displayName

  return tagged
}

/**
 * Copies every own enumerable string-keyed property of `attr` onto `comp` and returns `comp`.
 *
 * `comp` is mutated in place; the return value is the same reference, typed as `T & A` so
 * callers see the added statics.
 *
 * @param comp - Object or function to decorate.
 * @param attr - Properties to assign onto `comp`.
 * @returns `comp`, typed as `T & A`.
 */
export function withStatic<T, A extends Record<string, unknown> = { id: string }>(comp: T, attr: A) {
  const component = comp as T & A
  // `A` is constrained to a string-keyed record, so the intersection can be written through
  // a string-indexed view without resorting to `any`.
  const writable: Record<string, unknown> = component

  for (const key of Object.keys(attr)) {
    writable[key] = attr[key]
  }

  return component
}

/**
 * Tags a plain function component with public static props so it can be referenced by name
 * (e.g. as a stitches selector): `id`, `displayName` and `className` are set to
 * `displayName`, and `toString()` returns `.${displayName}`.
 *
 * `componentFn` itself is mutated and returned; no wrapper component is created.
 *
 * `ComponentProps` defaults to the props of `StyledPrimitiveType` combined (via `Override`)
 * with its variant props and then with `AdditionalProps`, plus `CommonProps` (`css` / `as`).
 *
 * @param displayName - Public name applied to the component.
 * @param componentFn - The function component to tag.
 * @returns `componentFn`, typed with the added static props.
 */
export function withStyled<
  StyledPrimitiveType extends JSXElementConstructor<any>,
  AdditionalProps extends object,
  CommonProps = { css?: CSS; as?: AsProps },
  ComponentProps = Override<
    Override<React.ComponentProps<StyledPrimitiveType>, VariantProps<StyledPrimitiveType>>,
    AdditionalProps
  > &
    CommonProps
>(displayName: string, componentFn: FC<ComponentProps>) {
  // `withStatic` mutates and returns this very function object, so there is no separate
  // inner component that could keep an `Internal…` displayName/toString (as the forwardRef
  // variant has). Setting those first would only be overwritten immediately, so only the
  // public statics are applied.
  return withStatic(componentFn, {
    id: displayName,
    displayName,
    toString: () => `.${displayName}`,
    className: displayName,
  })
}

/**
 * Wraps a ref-forwarding render function with `React.forwardRef` and tags the resulting
 * component with public static props: `id`, `displayName` and `className` are set to
 * `displayName`, and `toString()` returns `.${displayName}`.
 *
 * The inner render function keeps a distinct name, `Internal${displayName}` (its
 * `toString()` returns `.Internal${displayName}`), so the two can be told apart.
 *
 * @param displayName - Public name applied to the forwarded component.
 * @param componentFn - Render function receiving `(props, ref)`.
 * @returns The `forwardRef` component, typed with the added static props.
 */
export function withStyledRef<
  StyledPrimitiveType extends JSXElementConstructor<any>,
  AdditionalProps extends object,
  CommonProps = { css?: CSS; as?: AsProps },
  ComponentProps = Override<
    Override<React.ComponentProps<StyledPrimitiveType>, VariantProps<StyledPrimitiveType>>,
    AdditionalProps
  > &
    CommonProps
>(displayName: string, componentFn: ForwardRefRenderFunction<ElementRef<StyledPrimitiveType>, ComponentProps>) {
  const internalName = `Internal${displayName}`
  componentFn.displayName = internalName
  componentFn.toString = () => `.${internalName}`

  const forwarded = React.forwardRef<React.ElementRef<StyledPrimitiveType>, ComponentProps>(componentFn)
  const publicStatics = {
    id: displayName,
    displayName,
    toString: () => `.${displayName}`,
    className: displayName,
  }

  return withStatic(forwarded, publicStatics)
}

/**
 * A React element whose `type` may carry the static props attached by `withId` (or the
 * `withStyled*` helpers); `getChildType` reads `type.id` from it when present.
 */
interface ExtendedReactElement extends React.ReactElement {
  type: React.ReactElement['type'] & Partial<ExtraStaticProps>
}

/**
 * Flattens `children` with `React.Children.toArray` and keeps only the entries that are
 * valid React elements. Non-element nodes (e.g. plain strings and numbers) are dropped.
 *
 * @param children - Arbitrary React children.
 * @returns The valid elements, typed as `ExtendedReactElement[]`.
 */
export function getValidChildren(children: React.ReactNode) {
  const nodes = React.Children.toArray(children)
  const elements = nodes.filter((node) => React.isValidElement(node))
  return elements as ExtendedReactElement[]
}

/** Positional information passed by `mapChildren` to its callback for each valid child. */
export type mapFuncHelpers = {
  // 0-based position among the valid element children (non-elements are excluded)
  index: number
  // true for the first valid element child
  first: boolean
  // true for the last valid element child
  last: boolean
  // true when this is the first child with this `childType`
  firstOfType: boolean
  // true when this is the last child with this `childType`
  lastOfType: boolean
  // 1-based ordinal of this child among children sharing its `childType`
  nOfType: number
  // key computed by `getChildType` for this child
  childType: string
  // total number of valid children per `childType`; the same object is passed to every call
  childTypes: Record<string, number>
}

/**
 * Callback invoked by `mapChildren` once per valid element child, in order, with that
 * child and its positional helpers. Its return value becomes the mapped output entry.
 */
export type mapChildrenFunc = (
  child: ExtendedReactElement,
  helpers: mapFuncHelpers
) => React.ReactNode | ExtendedReactElement

/**
 * Returns a string key identifying the type of a React element child.
 *
 * If the element's `type` carries an `id` (e.g. set by `withId`, `withStyled` or
 * `withStyledRef`), that id is used. Otherwise the type is stringified: intrinsic tags yield
 * their tag name (e.g. `'div'`), and component types yield whatever their `toString()`
 * returns.
 *
 * `String()` is used instead of a template literal because some React element types (such
 * as `React.Fragment`) are symbols. Interpolating a symbol into a template literal throws a
 * `TypeError`, whereas `String(symbol)` returns e.g. `'Symbol(react.fragment)'`. For every
 * non-symbol value the two conversions produce the same string.
 *
 * NOTE(review): components without an `id` and with the default `toString` may produce long
 * keys (function source text) or colliding ones (`'[object Object]'` for memo/forwardRef
 * objects).
 */
export const getChildType = (child: ExtendedReactElement): string => child.type.id || String(child.type)

/**
 * Maps over the valid React element children of `children`. For each one it calls `mapFunc`
 * with positional helpers: overall index, first/last, per-type ordinal and per-type totals.
 *
 * Non-element children are dropped before mapping, so `index`, `first` and `last` refer to
 * positions among the remaining elements only. Types are keyed via `getChildType`.
 *
 * @param children - Arbitrary React children.
 * @param mapFunc - Called once per valid element, in order.
 * @returns The values returned by `mapFunc`, cast to `ExtendedReactElement[]`.
 */
export function mapChildren(children: React.ReactNode, mapFunc: mapChildrenFunc) {
  // Pass 1: keep valid elements, remember each one's type key, and tally totals per type.
  const entries = (React.Children.toArray(children) as ExtendedReactElement[])
    .filter((node) => React.isValidElement(node))
    .map((child) => ({ child, childType: getChildType(child) }))

  const childTypes: Record<string, number> = {}
  for (const { childType } of entries) {
    childTypes[childType] = (childTypes[childType] || 0) + 1
  }

  // Pass 2: invoke mapFunc, tracking how many of each type have been seen so far.
  const seenOfType: Record<string, number> = {}
  const lastIndex = entries.length - 1

  return entries.map(({ child, childType }, index) => {
    const nOfType = (seenOfType[childType] || 0) + 1
    seenOfType[childType] = nOfType

    return mapFunc(child, {
      index,
      first: index === 0,
      last: index === lastIndex,
      nOfType,
      firstOfType: nOfType === 1,
      lastOfType: nOfType === childTypes[childType],
      childType,
      childTypes,
    })
  }) as ExtendedReactElement[]
}
