import { Box, CircularProgress, TextField } from '@mui/material'
import { DataFrame } from 'danfojs'
import { Annotations, Data, Datum, PlotRelayoutEvent } from 'plotly.js'
import { ChangeEvent, useEffect, useMemo, useState } from 'react'
import Plot from 'react-plotly.js'
import { DifferentialExpressionResult } from '../../../../model/analysisCommands'
import GeneListSelector from './GeneListSelector'
import { ArrayType1D, ArrayType2D } from 'danfojs/dist/danfojs-base/shared/types'
import { getMaxAbsoluteValue } from '../../../../utils/misc'

/**
 * Hover-label configuration for a heatmap trace.
 *
 * The labels are interpolated into a Plotly heatmap `hovertemplate`; see
 * https://plotly.com/python/reference/heatmap/#heatmap-hovertemplate
 */
export interface HoverInfo {
    /** Label printed before the hovered x value (HeatmapPlot falls back to `'x'`). */
    xlabel?: string
    /** Label printed before the hovered y value (HeatmapPlot falls back to `'y'`). */
    ylabel?: string
    /** Label printed before the hovered cell value (HeatmapPlot falls back to `'z'`). */
    zlabel?: string
    /**
     * Extra per-cell values shown on hover. HeatmapPlot only uses them when the
     * outer length equals the frame's row count and the first inner array's
     * length equals its column count.
     */
    customData?: Datum[][] | null
    /** Label printed before the custom value (HeatmapPlot falls back to `'t'`). */
    customDataLabel?: string
}

/** Props accepted by the `HeatmapPlot` component. */
export interface HeatmapPlotParams {
    /** Frame to draw: columns become x categories, the index becomes y categories. */
    plotDataFrame: DataFrame
    /** Plot width in pixels (Plotly `layout.width`). */
    width: number
    /** Plot height in pixels (Plotly `layout.height`). */
    height: number
    /** x-axis title; HeatmapPlot falls back to `'Genes'`. */
    xlabel?: string
    /** Rotation of the x tick labels in degrees (Plotly `tickangle`). */
    xangle?: number
    /** y-axis title; HeatmapPlot falls back to `'Feature'`. */
    ylabel?: string
    /** Rotation of the y tick labels in degrees (Plotly `tickangle`). */
    yangle?: number
    /** Forwarded to Plotly's relayout event (zoom, axis reset, …). */
    onRelayout?: (e: Readonly<PlotRelayoutEvent>) => void
    /** Plot title; empty when omitted. */
    title?: string
    /** Title shown on the colour bar. */
    colorbarTitle?: string
    /**
     * Maps each row of the frame's values to the numbers used for colouring.
     * When provided, the colour range is also made symmetric around zero.
     */
    mapPlotColorValues?: (row: ArrayType2D) => ArrayType1D
    /** Maps a single cell's value to the text annotation drawn on that cell; no annotations when omitted. */
    mapPlotLabelValues?: (cell: ArrayType1D) => string
    /** Custom hover-label configuration; Plotly's default hover is used when omitted. */
    hoverInfo?: HoverInfo
    /** When true, the colour range is made symmetric around zero. */
    symmetricalColorScale?: boolean
}

/** Smallest left margin (px) reserved for the y-axis tick labels. */
const HEATMAP_MIN_LEFT_MARGIN = 60

/** Heuristic horizontal space (px) per character of a y tick label (tick font size is 8). */
const HEATMAP_PX_PER_LABEL_CHAR = 5

/** Above this many columns the x ticks are left to Plotly ('auto') instead of one per column ('linear'). */
const HEATMAP_MAX_LINEAR_TICK_COLUMNS = 50

/**
 * Returns `hoverInfo.customData` only when it is non-empty and its dimensions
 * match `df` (row count × column count); otherwise `undefined`, so Plotly never
 * receives per-cell data that is misaligned with the heatmap.
 */
function selectHeatmapCustomData(hoverInfo: HoverInfo | undefined, df: DataFrame): Datum[][] | undefined {
    const customData = hoverInfo?.customData
    if (!customData || customData.length === 0) {
        return undefined
    }
    const [rowCount, columnCount] = df.shape
    return customData.length === rowCount && customData[0]?.length === columnCount ? customData : undefined
}

/**
 * Builds the Plotly `hovertemplate` for the heatmap trace, or `undefined` when
 * no hover configuration is given.
 *
 * @param hoverInfo labels for the x / y / z (and optional custom data) lines
 * @param includeCustomData whether a `%{customdata}` line should be added
 */
function buildHeatmapHoverTemplate(hoverInfo: HoverInfo | undefined, includeCustomData: boolean): string | undefined {
    if (!hoverInfo) {
        return undefined
    }
    const lines = [
        `${hoverInfo.xlabel ?? 'x'}: %{x}`,
        `${hoverInfo.ylabel ?? 'y'}: %{y}`,
        `${hoverInfo.zlabel ?? 'z'}: %{z}`,
    ]
    if (includeCustomData) {
        lines.push(`${hoverInfo.customDataLabel ?? 't'}: %{customdata}`)
    }
    // An empty <extra></extra> tag hides Plotly's secondary hover box. It is appended
    // without a <br> separator so the hover label does not start with a blank line.
    return lines.join('<br>') + '<extra></extra>'
}

/**
 * Left margin (px) wide enough for the longest y-axis (index) label, never
 * smaller than HEATMAP_MIN_LEFT_MARGIN. Empty frames get the minimum.
 */
function computeHeatmapLeftMargin(df: DataFrame | undefined): number {
    if (!df || df.shape[0] === 0 || df.columns.length === 0) {
        return HEATMAP_MIN_LEFT_MARGIN
    }
    // reduce instead of Math.max(...labels): spreading a very long index can exceed the argument limit
    const longestLabel = df.index.reduce<number>((longest, label) => Math.max(longest, label.toString().length), 0)
    return Math.max(HEATMAP_PX_PER_LABEL_CHAR * longestLabel, HEATMAP_MIN_LEFT_MARGIN)
}

/**
 * Renders `plotDataFrame` as a Plotly heatmap: columns on the x axis, index on
 * the y axis, values (optionally transformed by `mapPlotColorValues`) as colours.
 *
 * When `mapPlotColorValues` is given or `symmetricalColorScale` is set, the
 * colour range is symmetric around zero. When `mapPlotLabelValues` is given,
 * every cell gets a white text annotation.
 */
export function HeatmapPlot({
    plotDataFrame,
    width,
    height,
    xlabel,
    xangle,
    ylabel,
    yangle,
    onRelayout,
    title,
    colorbarTitle,
    mapPlotColorValues,
    mapPlotLabelValues,
    hoverInfo,
    symmetricalColorScale,
}: HeatmapPlotParams) {
    const plotData = useMemo((): Data[] => {
        if (!plotDataFrame) {
            return []
        }

        const z = mapPlotColorValues
            ? plotDataFrame.values.map((row) => mapPlotColorValues(row as unknown as ArrayType2D))
            : plotDataFrame.values

        const customdata = selectHeatmapCustomData(hoverInfo, plotDataFrame)
        const hovertemplate = buildHeatmapHoverTemplate(hoverInfo, customdata !== undefined)

        // Symmetric colour range [-N, N] so that zero sits at the middle of the scale;
        // N is taken from getMaxAbsoluteValue(z).
        const colorLimit = mapPlotColorValues || symmetricalColorScale ? getMaxAbsoluteValue(z) : undefined

        return [
            {
                x: plotDataFrame.columns,
                y: plotDataFrame.index,
                z,
                type: 'heatmap',
                hoverongaps: false,
                colorbar: {
                    thickness: 10,
                    xanchor: 'left',
                    title: colorbarTitle,
                },
                customdata,
                hovertemplate,
                ...(colorLimit !== undefined ? { zmin: -colorLimit, zmax: colorLimit } : {}),
            } as Data,
        ]
        // every prop read above must be listed, otherwise the trace goes stale when it changes
    }, [plotDataFrame, colorbarTitle, mapPlotColorValues, symmetricalColorScale, hoverInfo])

    const plotAnnotationData = useMemo((): Partial<Annotations>[] => {
        if (!plotDataFrame || !mapPlotLabelValues) {
            return []
        }

        return plotDataFrame.index.flatMap((index) =>
            plotDataFrame.columns.map((column) => {
                return {
                    xref: 'x',
                    yref: 'y',
                    x: column,
                    y: index,
                    showarrow: false,
                    text: mapPlotLabelValues(plotDataFrame.at(index, column) as unknown as ArrayType1D),
                    font: {
                        color: 'white',
                    },
                } as Partial<Annotations>
            }),
        )
    }, [plotDataFrame, mapPlotLabelValues])

    const plotLeftMargin = useMemo(() => computeHeatmapLeftMargin(plotDataFrame), [plotDataFrame])

    const tickMode = useMemo((): 'auto' | 'linear' => {
        if (!plotDataFrame || plotDataFrame.shape[1] > HEATMAP_MAX_LINEAR_TICK_COLUMNS) {
            return 'auto'
        }
        return 'linear'
    }, [plotDataFrame])

    return (
        <Plot
            data={plotData}
            layout={{
                title: title ?? '',
                width: width,
                height: height,
                xaxis: {
                    title: {
                        text: xlabel ?? 'Genes',
                        font: {
                            size: 12,
                        },
                    },
                    tickfont: {
                        size: 8,
                        color: 'black',
                    },
                    tickmode: tickMode,
                    type: 'category',
                    showgrid: false,
                    zeroline: false,
                    showline: false,
                    // compare against undefined so an explicit 0° angle is honoured
                    ...(xangle !== undefined ? { tickangle: xangle } : {}),
                },
                yaxis: {
                    title: {
                        text: ylabel ?? 'Feature',
                        font: {
                            size: 12,
                        },
                    },
                    type: 'category',
                    tickfont: {
                        size: 8,
                        color: 'black',
                    },
                    showgrid: false,
                    zeroline: false,
                    showline: false,
                    ...(yangle !== undefined ? { tickangle: yangle } : {}),
                },
                margin: {
                    t: 35,
                    r: 0,
                    l: plotLeftMargin,
                },
                annotations: plotAnnotationData,
            }}
            config={{
                scrollZoom: false,
                displaylogo: false,
                displayModeBar: true,
                modeBarButtonsToRemove: ['lasso2d', 'select2d', 'pan2d', 'zoomIn2d', 'zoomOut2d', 'autoScale2d'],
                toImageButtonOptions: {
                    format: 'svg',
                    filename: 'heatmap',
                    scale: 2,
                },
            }}
            onRelayout={onRelayout}
        />
    )
}

/**
 * Props for the `DEGHeatmapPlot` component. Of the inherited heatmap props,
 * DEGHeatmapPlot forwards only width, height, xlabel, ylabel, hoverInfo and
 * symmetricalColorScale to the underlying heatmap.
 */
export interface DEGHeatmapPlotParams extends HeatmapPlotParams {
    /** Comparison whose `pValueCutoff` / `minAbsLogFC` seed the filter inputs. */
    comparison: DifferentialExpressionResult
    /** Differential-expression table; its `names`, `pvals_adj` and `logfoldchanges` columns are read. */
    degsDataFrame: DataFrame
}

/**
 * Returns the `names` of the rows in a differential-expression table that pass
 * both thresholds: `pvals_adj <= maxPval` and `|logfoldchanges| >= minAbsLogFC`.
 * A NaN threshold disables that filter, so unparseable (e.g. empty) user input
 * means "no limit". The input frame is not modified.
 */
function selectThresholdGenes(degsDf: DataFrame, maxPval: number, minAbsLogFC: number): string[] {
    let passing: DataFrame = degsDf
    if (!isNaN(minAbsLogFC)) {
        const logFC = passing['logfoldchanges']
        passing = passing
            .iloc({
                rows: logFC.ge(minAbsLogFC).or(logFC.le(-minAbsLogFC)),
            })
            .resetIndex()
    }
    if (!isNaN(maxPval)) {
        passing = passing
            .iloc({
                rows: passing['pvals_adj'].le(maxPval),
            })
            .resetIndex()
    }
    return passing['names'].values as string[]
}

/**
 * Heatmap of `plotDataFrame` restricted to differentially expressed genes.
 *
 * The user can narrow the displayed genes with a gene list, a maximum adjusted
 * p-value and a minimum |logFC|. The threshold inputs start at (and are reset
 * to) the cutoffs stored in `comparison`.
 */
export function DEGHeatmapPlot({
    plotDataFrame,
    comparison,
    degsDataFrame,
    width,
    height,
    xlabel,
    ylabel,
    hoverInfo,
    symmetricalColorScale,
}: DEGHeatmapPlotParams) {
    const [genes, setGenes] = useState<string[]>([])
    const [pval, setPval] = useState<string>(comparison.pValueCutoff.toString())
    const [logFoldChange, setLogFoldChange] = useState<string>(comparison.minAbsLogFC.toString())
    const [currentDf, setCurrentDf] = useState<DataFrame | null>(null)

    // When the data or the comparison changes, reset the inputs to the comparison's cutoffs
    // and show the columns whose genes pass them.
    useEffect(() => {
        if (!plotDataFrame) {
            return
        }
        setPval(comparison.pValueCutoff.toString())
        setLogFoldChange(comparison.minAbsLogFC.toString())
        const passing = new Set(
            selectThresholdGenes(degsDataFrame, comparison.pValueCutoff, comparison.minAbsLogFC),
        )
        const allColumns = plotDataFrame.columns
        // NOTE(review): the first column is always kept here, whereas filterByGenes keeps only
        // genes that pass the filters — confirm which behaviour is intended.
        const keptColumns = new Set([allColumns[0]])
        for (const column of allColumns) {
            if (passing.has(column)) {
                keptColumns.add(column)
            }
        }
        setCurrentDf(plotDataFrame.loc({ columns: Array.from(keptColumns) }))
    }, [plotDataFrame, comparison, degsDataFrame])

    /**
     * Applies the current p-value / |logFC| inputs and gene selection to the plotted frame.
     * An empty gene list means "every gene that passes the thresholds"; a non-numeric
     * threshold input disables that threshold.
     */
    const filterByGenes = () => {
        const thresholdGenes = selectThresholdGenes(degsDataFrame, parseFloat(pval), parseFloat(logFoldChange))
        const thresholdSet = new Set(thresholdGenes)

        // with a gene list, keep only listed genes that also pass the thresholds
        const finalGenes = genes.length > 0 ? genes.filter((gene) => thresholdSet.has(gene)) : thresholdGenes

        const plotColumns = new Set<string>(plotDataFrame.columns)
        setCurrentDf(
            plotDataFrame.loc({
                columns: finalGenes.filter((gene) => plotColumns.has(gene)),
            }),
        )
    }

    // Re-apply the filters whenever the gene selection changes. This also runs on mount,
    // after the effect above, so on mount its result replaces that effect's frame.
    useEffect(() => {
        filterByGenes()
    }, [genes])

    const handlePvalChange = (event: ChangeEvent<HTMLTextAreaElement | HTMLInputElement>) => {
        setPval(event.target.value)
    }

    const handleLogFoldChangeChange = (event: ChangeEvent<HTMLTextAreaElement | HTMLInputElement>) => {
        setLogFoldChange(event.target.value)
    }

    // Threshold inputs are applied only when the user presses Enter.
    const applyFiltersOnEnter = (event: { key: string }) => {
        if (event.key === 'Enter') {
            filterByGenes()
        }
    }

    const sampleGenes = useMemo(() => {
        return plotDataFrame.columns.slice(1, 10).join(', ')
    }, [plotDataFrame])

    // When both axes are reset to autorange (e.g. by double-clicking the plot), show the full frame.
    // NOTE(review): this drops the active p-value/logFC/gene filters while the inputs still
    // display them — confirm this reset is intended.
    const onRelayout = (e: Readonly<PlotRelayoutEvent>) => {
        if (e['xaxis.autorange'] && e['yaxis.autorange']) {
            setCurrentDf(plotDataFrame)
        }
    }

    return (
        <Box>
            <Box
                sx={{
                    display: 'flex',
                    width: '100%',
                    '.MuiTextField-root': {
                        mr: 1,
                    },
                }}
            >
                <GeneListSelector
                    sx={{
                        width: '600px',
                        mr: 1,
                    }}
                    handleGenesChange={(selected) => setGenes(selected)}
                    placeHolderGenes={sampleGenes}
                />
                <TextField
                    label={'Max p-value'}
                    sx={{ width: '100px' }}
                    variant='outlined'
                    placeholder={`E.g.: 0.01`}
                    onChange={handlePvalChange}
                    onKeyDown={applyFiltersOnEnter}
                    value={pval}
                />
                <TextField
                    label={'Min |logFC|'}
                    sx={{ width: '100px' }}
                    variant='outlined'
                    placeholder={`E.g.: 1`}
                    onChange={handleLogFoldChangeChange}
                    onKeyDown={applyFiltersOnEnter}
                    value={logFoldChange}
                />
            </Box>
            {currentDf ? (
                <HeatmapPlot
                    plotDataFrame={currentDf}
                    width={width}
                    height={height}
                    onRelayout={onRelayout}
                    xlabel={xlabel}
                    ylabel={ylabel}
                    symmetricalColorScale={symmetricalColorScale}
                    hoverInfo={hoverInfo}
                />
            ) : (
                <CircularProgress size={30} />
            )}
        </Box>
    )
}
