import { Box, Flex, Icon, Table, TableContainer, Tbody, Td, Text, Th, Thead, Tr } from '@chakra-ui/react';
import {
  Column,
  ColumnDef,
  ColumnSort,
  FilterFn,
  PaginationState,
  Row,
  RowData,
  SortingState,
  flexRender,
  getCoreRowModel,
  getFilteredRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  useReactTable,
} from '@tanstack/react-table';
import classNames from 'classnames';
import noop from 'lodash/noop';
import { CSSProperties, RefCallback, RefObject, useEffect, useMemo, useRef, useState } from 'react';

import theme from '../../theme';
import { SortAscendingIcon, SortDescendingIcon, SortIndeterminateIcon } from '../Icon';
import { Pagination } from './Pagination';

import classes from './DataTable.module.scss';

/**
 * Augment TanStack Table's `TableMeta` typing, as the library's docs recommend, so that `meta.getRowStyles` is typed.
 * `DataTable` passes its `getRowStyles` prop through the table's `meta` option, and the render code reads it back via
 * `table.options.meta?.getRowStyles`.
 *
 * The type parameter must be named `TData` and constrained to `RowData`, exactly as in the library's own declaration.
 * TypeScript only merges interface declarations whose type parameters are identical. Outside this block, the
 * component's own type parameter (`Data`) is used instead.
 * @see https://tanstack.com/table/v8/docs/api/core/table#meta
 * @see https://github.com/TanStack/table/issues/44
 */
declare module '@tanstack/react-table' {
  interface TableMeta<TData extends RowData> {
    getRowStyles?: (row: Row<TData>) => CSSProperties;
  }
}

/**
 * Compute inline styles for a header or body cell based on its column's pinning state.
 *
 * - Pinned columns become `position: sticky`, are offset by the combined width of the pinned columns before them,
 *   and are raised to `zIndex: 20` so scrolled content passes underneath.
 * - Unpinned columns stay `position: relative` with `zIndex: 0`.
 * - The outermost pinned column on each side gets a box-shadow "divider", because borders do not survive sticky
 *   positioning.
 *
 * Kept as a function rather than conditional classnames because the offsets and width are dynamic per column.
 */
const getColumnPinningStyles = <TData extends RowData>(column: Column<TData>): CSSProperties => {
  const pinnedSide = column.getIsPinned();
  const dividerColor = theme.colors.brand.gray[100];

  let boxShadow: string | undefined;
  if (pinnedSide === 'left' && column.getIsLastColumn('left')) {
    // 1px inset line on the right edge of the last left-pinned column.
    boxShadow = `inset -1px 0 0 0 ${dividerColor}`;
  } else if (pinnedSide === 'right' && column.getIsFirstColumn('right')) {
    // NOTE(review): this shadow has a positive x offset and is not `inset`, so it is cast to the column's right side
    // rather than its left edge. Right pinning is not used by DataTable today. Confirm the intent before relying on it.
    boxShadow = `0.25rem 0 0.25rem -0.25rem ${dividerColor}`;
  }

  return {
    boxShadow,
    left: pinnedSide === 'left' ? `${column.getStart('left')}px` : undefined,
    right: pinnedSide === 'right' ? `${column.getAfter('right')}px` : undefined,
    position: pinnedSide ? 'sticky' : 'relative',
    width: column.getSize(),
    zIndex: pinnedSide ? 20 : 0,
  };
};

/** Props for the generic `DataTable` component. `Data` is the shape of one row's original data. */
export type DataTableProps<Data extends RowData> = {
  /** List of columns to render in the table. */
  columns: ColumnDef<Data, any>[];
  /** List of rows to render in the table. */
  data: (Data & { key?: string })[];
  /**
   * Default sorting object (column id and direction). Only read on mount to seed the sorting state.
   * @see https://tanstack.com/table/v8/docs/api/features/sorting
   */
  defaultSort?: ColumnSort;
  /**
   * Optional function that allows the developer to specify a value to be used as the row ID. Usually this would be some
   * property from the row.
   */
  getRowId?: (row: Data, index: number, parent?: Row<Data>) => string;
  /**
   * Optional function returning extra inline styles for a given row. If the returned styles include a
   * `backgroundColor`, it is also applied to every cell of that row.
   */
  getRowStyles?: (row: Row<Data>) => CSSProperties;
  /**
   * The value of the global filter, to be tracked internally by the table state. Provide this value when using a global
   * filter (such as a high-level text input) which filters all columns at once. Its type is whatever `globalFilterFn`
   * expects.
   */
  globalFilter?: unknown;
  /** Filter function applied globally to the table. Filters across all columns at once. */
  globalFilterFn?: FilterFn<Data>;
  /**
   * Whether to render the headers with reduced horizontal padding. As long as the data's rendered content is either
   * thin or responsive enough, the column should be able to render at a low width.
   */
  isCompact?: boolean;
  /** Message to show when the table has no rows to display. */
  noDataMessage?: string;
  /** Optional function called with the row's original data when a (non-pinned) row is clicked. */
  onRowClick?: (row: Data) => void;
  /**
   * The initial page size (for pagination). Only read on mount. Not needed if `shouldUsePagination` is `false`.
   * Defaults to 10.
   */
  pageSize?: number;
  /** Unit used to label rows in the pagination control. */
  paginationRowUnit?: string;
  /**
   * The ID of a column to pin to the left edge of the table. Must be the ID of a column in the table. Only read on
   * mount (it is passed as initial table state).
   */
  pinnedColumnId?: string;
  /**
   * An optional row to pin at the top of the table. It is only added when `data` is non-empty.
   *
   * The row is pinned by its `key`. Row pinning matches on row IDs, so `getRowId` must return `key` for this row or
   * the pin has no effect. This coupling should be refactored.
   */
  pinnedRow?: Data & { key?: string };
  /** Whether or not the table should use pagination. Defaults to `false`. */
  shouldUsePagination?: boolean;
  /**
   * Ref (object or callback) attached to the table container element. Allows external access to scrolling and other
   * properties of the container.
   */
  tableContainerRef?: RefCallback<HTMLDivElement> | RefObject<HTMLDivElement>;
};

/**
 * Generic data table built on TanStack Table (v8) and Chakra UI.
 *
 * Supports:
 * - click-to-sort headers and a sticky header;
 * - an optional left-pinned column (`pinnedColumnId`) and an optional sticky top-pinned row (`pinnedRow`);
 * - an optional global filter and optional client-side pagination;
 * - per-row style overrides (`getRowStyles`) and row click handling (`onRowClick`).
 *
 * The `data` prop is never mutated.
 */
export const DataTable = <Data extends RowData>({
  columns,
  data,
  defaultSort,
  getRowId,
  getRowStyles,
  globalFilter,
  globalFilterFn,
  isCompact,
  noDataMessage = 'No results were found for your search.',
  onRowClick,
  pageSize = 10,
  paginationRowUnit = '',
  pinnedColumnId,
  pinnedRow,
  shouldUsePagination = false,
  tableContainerRef,
}: DataTableProps<Data>) => {
  // Both states are seeded from props on mount only; later prop changes are not synced.
  const [sorting, setSorting] = useState<SortingState>(defaultSort ? [defaultSort] : []);
  const [pagination, setPagination] = useState<PaginationState>({
    pageIndex: 0,
    pageSize,
  });

  const tableRef = useRef<HTMLTableElement>(null);
  const tableHeadRef = useRef<HTMLTableSectionElement>(null);

  // Measured height of the sticky header. The top-pinned row uses it as its sticky `top` offset, so it sits just
  // below the header instead of underneath it.
  const [tableHeadHeight, setTableHeadHeight] = useState(0);

  // `<Thead>` is rendered unconditionally, so its element exists after the first commit and observing it once on
  // mount is enough. Refs are mutable and read during render, so they must not be used as effect dependencies.
  useEffect(() => {
    const tableHead = tableHeadRef.current;
    // Skip when the element is missing or `ResizeObserver` is unavailable (e.g. some test environments).
    if (!tableHead || typeof ResizeObserver === 'undefined') {
      return undefined;
    }

    const resizeObserver = new ResizeObserver(() => {
      // React skips the re-render when the new value equals the current state, so no manual comparison is needed.
      setTableHeadHeight(Math.floor(tableHead.getBoundingClientRect().height));
    });
    resizeObserver.observe(tableHead);

    return () => {
      resizeObserver.disconnect();
    };
  }, []);

  /**
   * Scroll to the top of the table element. Used when clicking pagination buttons.
   * NOTE(review): this scrolls the `<table>` element itself. If the scrollable element is the surrounding
   * `TableContainer` (depends on `DataTable.module.scss`), this call has no visible effect. Confirm.
   */
  const scrollToTableTop = () => {
    tableRef.current?.scrollTo({ top: 0 });
  };

  // Append the pinned row to a copy of `data`, never to `data` itself. Pushing into the prop would leak into the
  // caller's array, add another copy on every render, and keep the same array reference, which hides the change from
  // TanStack. Memoized so the row models are only rebuilt when the inputs change.
  const tableData = useMemo(
    () => (pinnedRow && data.length > 0 ? [...data, pinnedRow] : data),
    [data, pinnedRow],
  );

  /**
   * Note that in the hook below, the `autoResetPageIndex` option was causing opening drawers to reset the pagination
   * state.
   * @see https://tanstack.com/table/v8/docs/guide/pagination#auto-reset-page-index
   */
  const table = useReactTable<Data>({
    autoResetPageIndex: false,
    columns,
    data: tableData,
    getCoreRowModel: getCoreRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    onSortingChange: (updater) => {
      setSorting(updater);
      // Return to the first page whenever the sort changes. The functional update avoids a stale `pagination` closure.
      setPagination((previous) => ({ ...previous, pageIndex: 0 }));
    },
    getSortedRowModel: getSortedRowModel(),
    onPaginationChange: setPagination,
    getRowId,
    getPaginationRowModel: shouldUsePagination ? getPaginationRowModel() : undefined,
    initialState: {
      columnPinning: {
        left: pinnedColumnId ? [pinnedColumnId] : [],
      },
      rowPinning: {
        top: pinnedRow && pinnedRow.key ? [pinnedRow.key] : [],
      },
    },
    state: {
      sorting,
      globalFilter,
      pagination,
    },
    globalFilterFn,
    meta: {
      getRowStyles,
    },
  });

  // Rows that are not pinned. The header, the pinned rows and the empty-state message are all driven by this list,
  // so the table never ends up blank (e.g. when only the pinned row survives filtering).
  const centerRows = table.getCenterRows();
  const hasCenterRows = centerRows.length > 0;

  // A `const` copy so TypeScript keeps the narrowing inside the per-row click closures.
  const rowClickHandler = typeof onRowClick === 'function' ? onRowClick : undefined;

  /**
   * Render a row pinned to the top of the body. It is sticky and offset by the header height, so it stays visible
   * below the sticky header while scrolling.
   *
   * This is a plain render function, not a component declared inside `DataTable`. A component defined during render
   * gets a new identity on every render, which makes React unmount and remount it each time.
   */
  const renderTopPinnedRow = (row: Row<Data>) => {
    const rowStyles = table.options.meta?.getRowStyles?.(row);
    const bottomDividerShadow = `inset 0 -1px 0 ${theme.colors.brand.gray[100]}`;

    return (
      <Tr
        key={row.id}
        backgroundColor={theme.colors.brand.gray[50]}
        height="100%"
        position="sticky"
        top={row.getIsPinned() === 'top' ? `${tableHeadHeight}px` : undefined}
        zIndex={40}
        _hover={{
          backgroundColor: theme.colors.brand.gray[100],
        }}
      >
        {row.getVisibleCells().map((cell) => {
          const pinningStyles = getColumnPinningStyles(cell.column);

          return (
            <Td
              key={cell.id}
              style={{
                ...pinningStyles,
                // Use a box shadow instead of a border so the "border" survives sticky scrolling. Append it to the
                // column's own pinning shadow when there is one.
                boxShadow: pinningStyles.boxShadow
                  ? `${pinningStyles.boxShadow}, ${bottomDividerShadow}`
                  : bottomDividerShadow,
                border: 'none',
              }}
              height="100%"
              backgroundColor={rowStyles?.backgroundColor ?? theme.colors.brand.gray[50]}
            >
              {flexRender(cell.column.columnDef.cell, cell.getContext())}
            </Td>
          );
        })}
      </Tr>
    );
  };

  return (
    <Box width="100%" display="flex" flexDirection="column" overflow="hidden" borderRadius="0.5rem">
      <TableContainer
        className={classNames(classes.dataTableContainer)}
        data-testid="table-container"
        ref={tableContainerRef}
      >
        <Table
          backgroundColor={theme.colors.white}
          borderRadius="0.5rem"
          className={classNames(classes.dataTable)}
          height="fit-content"
          ref={tableRef}
        >
          <Thead
            backgroundColor={theme.colors.brand.gray[50]}
            position="sticky"
            ref={tableHeadRef}
            top={0}
            userSelect="none"
          >
            {hasCenterRows &&
              table.getHeaderGroups().map((headerGroup) => (
                <Tr key={headerGroup.id}>
                  {headerGroup.headers.map((header) => {
                    // see https://tanstack.com/table/v8/docs/api/core/column-def#meta to type this correctly
                    const meta: any = header.column.columnDef.meta;
                    let sortDirection: 'ASC' | 'DESC' | undefined;

                    if (header.column.getIsSorted() === 'asc') {
                      sortDirection = 'ASC';
                    } else if (header.column.getIsSorted() === 'desc') {
                      sortDirection = 'DESC';
                    }

                    return (
                      <Th
                        key={header.id}
                        backgroundColor={theme.colors.brand.gray[50]}
                        border="none"
                        className={classNames(classes.header, { [classes.sortableHeader]: header.column.getCanSort() })}
                        isNumeric={meta?.isNumeric}
                        onClick={header.column.getToggleSortingHandler()}
                        padding={isCompact ? '0.875rem 0.5rem' : '0.875rem 1rem'}
                        style={getColumnPinningStyles(header.column)}
                      >
                        <Flex
                          alignItems="center"
                          color={theme.colors.brand.gray[600]}
                          fontWeight="bold"
                          justifyContent="space-between"
                          width="100%"
                        >
                          {flexRender(header.column.columnDef.header, header.getContext())}
                          {header.column.getCanSort() && (
                            <Icon
                              aria-hidden
                              as={
                                sortDirection === 'ASC'
                                  ? SortAscendingIcon
                                  : sortDirection === 'DESC'
                                  ? SortDescendingIcon
                                  : SortIndeterminateIcon
                              }
                              height="1rem"
                              marginInlineStart="0.25rem"
                              width="1rem"
                              direction={sortDirection}
                            />
                          )}
                        </Flex>
                      </Th>
                    );
                  })}
                </Tr>
              ))}
          </Thead>
          <Tbody>
            {hasCenterRows && table.getTopRows().map((row) => renderTopPinnedRow(row))}
            {centerRows.map((row) => {
              // Compute the row's style overrides once and share them with every cell.
              const rowStyles = table.options.meta?.getRowStyles?.(row);

              return (
                <Tr
                  key={row.id}
                  cursor={rowClickHandler ? 'pointer' : 'default'}
                  height="100%"
                  // if onRowClick is defined, call it with the data for the clicked row (`row.original`)
                  onClick={rowClickHandler ? () => rowClickHandler(row.original) : noop}
                  style={rowStyles}
                  _hover={{
                    backgroundColor: theme.colors.brand.gray[50],
                  }}
                >
                  {row.getVisibleCells().map((cell) => {
                    // see https://tanstack.com/table/v8/docs/api/core/column-def#meta to type this correctly
                    const meta: any = cell.column.columnDef.meta;

                    return (
                      <Td
                        key={cell.id}
                        backgroundColor={rowStyles?.backgroundColor ?? theme.colors.white}
                        height="100%"
                        isNumeric={meta?.isNumeric}
                        style={getColumnPinningStyles(cell.column)}
                      >
                        {flexRender(cell.column.columnDef.cell, cell.getContext())}
                      </Td>
                    );
                  })}
                </Tr>
              );
            })}
            {(!hasCenterRows || table.getAllColumns().length === 0) && (
              <Tr>
                <Td colSpan={columns.length}>
                  <Text
                    className={classes.emptyList}
                    color={theme.colors.brand.gray[600]}
                    padding="1.5rem"
                    textAlign="center"
                    textStyle="detail"
                  >
                    {noDataMessage}
                  </Text>
                </Td>
              </Tr>
            )}
          </Tbody>
        </Table>
      </TableContainer>
      {shouldUsePagination && (
        <Flex alignItems="flex-end" flexDirection="column" marginTop="0.5rem">
          <Pagination
            pageCount={table.getPageCount()}
            displayedRowCount={table.getRowModel().rows.length}
            totalRowCount={data.length}
            setPageIndex={(index) => {
              table.setPageIndex(index);
              scrollToTableTop();
            }}
            pageIndex={table.getState().pagination.pageIndex}
            canPreviousPage={table.getCanPreviousPage()}
            goToPreviousPage={() => {
              table.previousPage();
              scrollToTableTop();
            }}
            canNextPage={table.getCanNextPage()}
            goToNextPage={() => {
              table.nextPage();
              scrollToTableTop();
            }}
            rowUnit={paginationRowUnit}
          />
        </Flex>
      )}
    </Box>
  );
};
