import { Box, TextField } from '@mui/material'
import { DataFrame } from 'danfojs'
import { Data, PlotRelayoutEvent, PlotSelectionEvent, Shape } from 'plotly.js'
import React, { ChangeEvent, useEffect, useMemo, useState } from 'react'
import Plot from 'react-plotly.js'
import { DifferentialAnalysisThresholds } from '../../../../model/analysisCommands'
import { getDefaultMarkerSize } from '../../../../utils/misc'
import GeneListSelector from './GeneListSelector'
import MarkerSizeControl from './MarkerSizeControl'
import { useAppDispatch } from '../../../../app/hooks'
import { receivedLassoSelections } from '../analysisDataHubSlice'

/** Props for the VolcanoPlot component. */
export interface VolcanoPlotParams {
    /** Differential-analysis thresholds; `pValueCutoff` and `minAbsLogFC` seed the editable cutoff text fields. */
    comparison: DifferentialAnalysisThresholds
    /**
     * Result table to plot. VolcanoPlot reads (at least) the `-logQ` and `logfoldchanges` columns and the
     * `dataFieldName` column; `-logQ` may contain the string `'inf'`, which is replaced by a finite value
     * before plotting.
     */
    df: DataFrame
    /** Plot width, passed straight through to the Plotly layout. */
    width: number
    /** Plot height, passed straight through to the Plotly layout. */
    height: number
    /** When true, the lasso tool is also removed from the Plotly mode bar. */
    disableLasso?: boolean
    /** When true (VolcanoPlot's default), a gene selector is shown and the plot is filtered to the chosen genes. */
    genePlot?: boolean
    /** Column whose values are used as point text labels and lasso-selection payloads; VolcanoPlot defaults it to 'names'. */
    dataFieldName?: string
}

/**
 * Interactive volcano plot (log fold change vs. -log Q-value) for a differential-analysis result.
 *
 * Points meeting both the p-value and |logFC| cutoffs are drawn in orange, the rest in blue, with dotted
 * guide lines at the cutoffs. The cutoffs are editable text fields seeded from `comparison`.
 * Zooming narrows the plotted rows to the zoomed window; an autorange on both axes restores the full
 * dataset (still restricted to the selected genes when `genePlot` is on). Lasso selections are dispatched
 * to the store via `receivedLassoSelections`, using the points' `dataFieldName` labels.
 */
export default function VolcanoPlot({
    comparison,
    df,
    width,
    height,
    disableLasso,
    genePlot = true,
    dataFieldName = 'names',
}: VolcanoPlotParams) {
    const dispatch = useAppDispatch()
    const [markerSize, setMarkerSize] = useState(0)
    // Rows currently plotted: boundedDf, optionally narrowed by the gene filter and/or by zooming.
    const [currentDf, setCurrentDf] = useState<DataFrame | null>(null)
    // Cutoffs are held as the raw text-field strings and parsed when the plot is built.
    const [pval, setPval] = useState<string>(comparison.pValueCutoff.toString())
    const [logFoldChange, setLogFoldChange] = useState<string>(comparison.minAbsLogFC.toString())
    const [genes, setGenes] = useState<string[]>([])
    // this is a hack to redraw the plot and exit the selection state
    const [synthetic, setSynthetic] = useState(0)

    // '-logQ' may contain the string 'inf'. Replace it with a finite value one above the largest finite
    // entry so those points sit at the top of the plot, then coerce the column to float32.
    const boundedDf = useMemo(() => {
        let logQMax = Number.MIN_SAFE_INTEGER
        df.column('-logQ').values.forEach((v) => {
            // Skip NaN/±Infinity: Math.max would propagate them into the cap and the 'inf' rows
            // would then be plotted at a non-finite y (i.e. disappear).
            if (typeof v === 'number' && Number.isFinite(v)) {
                logQMax = Math.max(logQMax, v)
            }
        })

        return df.replace('inf', logQMax + 1, { columns: ['-logQ'] }).asType('-logQ', 'float32')
    }, [df])

    // Plot only rows whose label (the dataFieldName column) is among the selected genes;
    // with no genes selected, plot everything.
    const filterByGenes = () => {
        if (genes.length === 0) {
            setCurrentDf(boundedDf)
            return
        }

        const selected = new Set(genes)
        const filteredDf = boundedDf
            .loc({
                rows: boundedDf[dataFieldName].values.map((x: string) => selected.has(x)),
            })
            .resetIndex()
            .copy()
        setCurrentDf(filteredDf)
    }

    // Drop any zoom-based narrowing while keeping the gene filter in force for gene plots.
    const resetCurrentDf = () => {
        if (genePlot) {
            filterByGenes()
        } else {
            setCurrentDf(boundedDf)
        }
    }

    // Re-initialise plotted rows, marker size and cutoffs when a new dataset or new thresholds arrive.
    useEffect(() => {
        resetCurrentDf()
        setMarkerSize(getDefaultMarkerSize(boundedDf.shape[0]))
        setPval(comparison.pValueCutoff.toString())
        setLogFoldChange(comparison.minAbsLogFC.toString())
    }, [boundedDf, comparison])

    useEffect(() => {
        if (genePlot) {
            filterByGenes()
        }
    }, [genes, genePlot, dataFieldName])

    const plotData = useMemo(() => {
        if (!currentDf || !pval || !logFoldChange) {
            return [] as Data[]
        }
        const pvalFloat = parseFloat(pval)
        const logFoldChangeFloat = parseFloat(logFoldChange)
        if (isNaN(pvalFloat) || isNaN(logFoldChangeFloat)) {
            return [] as Data[]
        }
        const minusLogQCutoff = -Math.log10(pvalFloat)
        // "High" = passes both cutoffs; "low" is its exact complement.
        const highPlotDataFrame = currentDf.query(
            currentDf['-logQ']
                .ge(minusLogQCutoff)
                .and(
                    currentDf['logfoldchanges']
                        .ge(logFoldChangeFloat)
                        .or(currentDf['logfoldchanges'].le(-logFoldChangeFloat)),
                ),
        )
        const lowPlotDataFrame = currentDf.query(
            currentDf['-logQ']
                .lt(minusLogQCutoff)
                .or(
                    currentDf['logfoldchanges']
                        .lt(logFoldChangeFloat)
                        .and(currentDf['logfoldchanges'].gt(-logFoldChangeFloat)),
                ),
        )

        // Only draw text labels when few enough points are shown for them to be legible.
        const showText = currentDf.shape[0] <= 40

        return [
            {
                type: 'scattergl',
                mode: showText ? 'text+markers' : 'markers',
                x: lowPlotDataFrame['logfoldchanges'].values,
                y: lowPlotDataFrame['-logQ'].values,
                text: lowPlotDataFrame[dataFieldName].values,
                textposition: 'bottom center',
                marker: {
                    size: markerSize,
                    color: '#0288d1',
                },
            },
            {
                type: 'scattergl',
                mode: showText ? 'text+markers' : 'markers',
                x: highPlotDataFrame['logfoldchanges'].values,
                y: highPlotDataFrame['-logQ'].values,
                text: highPlotDataFrame[dataFieldName].values,
                textposition: 'bottom center',
                marker: {
                    size: markerSize,
                    color: '#D45500',
                },
            },
        ] as Data[]
        // `synthetic` is deliberately listed so bumping it rebuilds the traces (clears selection state).
    }, [currentDf, pval, logFoldChange, markerSize, synthetic, dataFieldName])

    // Dotted guide lines: one horizontal at the p-value cutoff, two vertical at ±logFC cutoff.
    const plotShapes = useMemo(() => {
        const shapes = [] as Shape[]
        if (pval) {
            const pvalFloat = parseFloat(pval)
            if (!isNaN(pvalFloat)) {
                shapes.push({
                    type: 'line',
                    xref: 'paper',
                    x0: 0,
                    y0: -Math.log10(pvalFloat),
                    x1: 1,
                    y1: -Math.log10(pvalFloat),
                    line: {
                        color: '#848484',
                        width: 2,
                        dash: 'dot',
                    },
                } as Shape)
            }
        }
        if (logFoldChange) {
            const logFoldChangeFloat = parseFloat(logFoldChange)
            if (!isNaN(logFoldChangeFloat)) {
                shapes.push({
                    type: 'line',
                    xref: 'x',
                    yref: 'paper',
                    x0: -logFoldChangeFloat,
                    y0: 0,
                    x1: -logFoldChangeFloat,
                    y1: 1,
                    line: {
                        color: '#848484',
                        width: 2,
                        dash: 'dot',
                    },
                } as Shape)
                shapes.push({
                    type: 'line',
                    xref: 'x',
                    yref: 'paper',
                    x0: logFoldChangeFloat,
                    y0: 0,
                    x1: logFoldChangeFloat,
                    y1: 1,
                    line: {
                        color: '#848484',
                        width: 2,
                        dash: 'dot',
                    },
                } as Shape)
            }
        }
        return shapes
    }, [pval, logFoldChange])

    // Zoom narrows the plotted rows to the new axis window; autorange on both axes restores them.
    const onRelayout = (e: Readonly<PlotRelayoutEvent>) => {
        if (!currentDf) {
            return
        }
        const xRange0 = e['xaxis.range[0]'],
            xRange1 = e['xaxis.range[1]'],
            yRange0 = e['yaxis.range[0]'],
            yRange1 = e['yaxis.range[1]']

        const logFC = currentDf['logfoldchanges']
        const logQ = currentDf['-logQ']
        // Compare with null rather than truthiness: 0 is a legitimate axis bound.
        const xMask = xRange0 != null && xRange1 != null ? logFC.ge(xRange0).and(logFC.le(xRange1)) : null
        const yMask = yRange0 != null && yRange1 != null ? logQ.ge(yRange0).and(logQ.le(yRange1)) : null
        const rows = xMask && yMask ? xMask.and(yMask) : xMask ?? yMask

        if (rows) {
            setCurrentDf(currentDf.iloc({ rows }).resetIndex())
        } else if (e['xaxis.autorange'] && e['yaxis.autorange']) {
            resetCurrentDf()
        }
    }

    const handlePvalChange = (event: ChangeEvent<HTMLTextAreaElement | HTMLInputElement>) => {
        setPval(event.target.value)
    }

    const handleLogFoldChangeChange = (event: ChangeEvent<HTMLTextAreaElement | HTMLInputElement>) => {
        setLogFoldChange(event.target.value)
    }

    // First ten labels, shown as the gene selector's placeholder example.
    const sampleGenes = useMemo(() => {
        return boundedDf[dataFieldName].values.slice(0, 10).join(', ')
    }, [boundedDf, dataFieldName])

    // Forward lasso/box selections (as point labels) to the store.
    const onSelected = (e: Readonly<PlotSelectionEvent>) => {
        // @ts-expect-error Only exists sometimes.
        if (e && e.selections && e.selections.length > 0) {
            const selection = e.points.map((p) => p.text)
            dispatch(receivedLassoSelections({ lassoSelections: [selection] }))
        }
    }

    return (
        <Box>
            <Box sx={{ display: 'flex', justifyContent: 'space-between' }}>
                <Box
                    sx={{
                        display: 'flex',
                        '.MuiTextField-root': {
                            mr: 1,
                        },
                    }}
                >
                    {genePlot && (
                        <GeneListSelector
                            sx={{
                                width: '500px',
                                mr: 1,
                            }}
                            placeHolderGenes={sampleGenes}
                            handleGenesChange={(selectedGenes) => setGenes(selectedGenes)}
                        />
                    )}

                    <TextField
                        label={'Max p-value'}
                        sx={{ width: '100px' }}
                        variant='outlined'
                        placeholder={`E.g.: 0.01`}
                        onChange={handlePvalChange}
                        onKeyDown={(e) => {
                            if (e.key === 'Enter') {
                                filterByGenes()
                            }
                        }}
                        value={pval}
                    />
                    <TextField
                        label={'Min |logFC|'}
                        sx={{ width: '100px' }}
                        variant='outlined'
                        placeholder={`E.g.: 1`}
                        onChange={handleLogFoldChangeChange}
                        onKeyDown={(e) => {
                            if (e.key === 'Enter') {
                                filterByGenes()
                            }
                        }}
                        value={logFoldChange}
                    />
                </Box>
                <MarkerSizeControl markerSize={markerSize} setMarkerSize={setMarkerSize} />
            </Box>
            <Plot
                data={plotData}
                layout={{
                    title: {
                        font: {
                            size: 12,
                        },
                    },
                    width: width,
                    height: height,
                    xaxis: {
                        title: {
                            text: 'log FC',
                            font: {
                                size: 12,
                            },
                        },
                    },
                    yaxis: {
                        title: {
                            text: '-log Q-value',
                            font: {
                                size: 12,
                            },
                        },
                    },
                    margin: {
                        t: 35,
                        r: 10,
                        b: 40,
                    },
                    shapes: plotShapes,
                }}
                config={{
                    scrollZoom: false,
                    displaylogo: false,
                    modeBarButtonsToRemove: disableLasso
                        ? ['lasso2d', 'select2d', 'zoomIn2d', 'zoomOut2d', 'autoScale2d']
                        : ['select2d', 'zoomIn2d', 'zoomOut2d', 'autoScale2d'],
                }}
                onRelayout={onRelayout}
                onSelected={onSelected}
                onClick={() => {
                    setSynthetic((n) => n + 1)
                }}
            />
        </Box>
    )
}
