import React, { useState, useRef, useEffect } from 'react';
import styles from "./NNVisRegion.module.less"
import commonStyle from '../common/ComponentCommons';
import BaseDragableElement from '../common/BaseDragableElement';
import { IntrinsicElementProps } from '../common/BaseDragableElement';
import NNVisRegionContent from '@/base/ElementData/NNVisRegionContent';
import * as tf from '@tensorflow/tfjs';
import { useMediaQuery } from 'react-responsive';

/** Side length, in cells, of the square image fed to the model and of the overlay grid. */
const GRID_SIZE = 28;
/** Unscaled side length of the drawing canvas, in CSS pixels. */
const BASE_CANVAS_SIZE = 280;
/** Number of units rendered for the last, second-to-last and third-to-last layers. */
const LAST_LAYER_SIZE = 10;
const SECOND_LAST_LAYER_SIZE = 25;
const THIRD_LAST_LAYER_SIZE = 25;
/** How many trailing model layers are visualised. */
const VISUALISED_LAYER_COUNT = 3;
const MODEL_URL = 'https://nf-internal.oss-cn-beijing.aliyuncs.com/model_l5/model.json';
/** Pen stroke width, in canvas pixels. */
const PEN_WIDTH = 20;

type Point = { x: number; y: number };
/** Visualised layer a hovered node belongs to ('first' is the model's last layer). */
type LayerName = 'first' | 'second' | 'third';
/** Five colours, one per 20% bucket from 0% to 100%. */
type Palette = readonly [string, string, string, string, string];

/** Bar colours for layer activations. */
const BAR_PALETTE: Palette = [
    'rgba(200, 200, 200, 1)', // 0% - 20% (light grey)
    'rgba(150, 150, 150, 1)', // 20% - 40% (mid grey)
    'rgba(100, 100, 100, 1)', // 40% - 60%
    'rgba(50, 50, 50, 1)',    // 60% - 80%
    'rgba(0, 0, 0, 1)',       // 80% - 100% (black)
];

/** Cell colours for the 28x28 input grid; the lowest bucket is nearly transparent. */
const GRID_PALETTE: Palette = [
    'rgba(0, 0, 0, 0.1)',     // 0% - 20% (faint)
    'rgba(150, 150, 150, 1)', // 20% - 40% (mid grey)
    'rgba(100, 100, 100, 1)', // 40% - 60%
    'rgba(50, 50, 50, 1)',    // 60% - 80%
    'rgba(0, 0, 0, 1)',       // 80% - 100% (black)
];

/** Picks the palette entry for a percentage; values >= 80 (and NaN) map to the last entry. */
const pickColor = (palette: Palette, percentage: number): string => {
    if (percentage < 20) return palette[0];
    if (percentage < 40) return palette[1];
    if (percentage < 60) return palette[2];
    if (percentage < 80) return palette[3];
    return palette[4];
};

const getColorForPercentage = (percentage: number): string => pickColor(BAR_PALETTE, percentage);
const getGridColorForPercentage = (percentage: number): string => pickColor(GRID_PALETTE, percentage);

/** Scales activations by 100 so they can be used directly as CSS height percentages. */
const toPercentages = (outputs: readonly number[]): number[] => outputs.map(output => output * 100);

const zeros = (length: number): number[] => Array(length).fill(0);
const blankPixelColors = (): string[] => Array(GRID_SIZE * GRID_SIZE).fill('');

/** Maps RGBA pixel data to one grid colour per pixel, based on luma brightness (0-100%). */
const toPixelColors = (data: Uint8ClampedArray): string[] => {
    const colors: string[] = [];
    for (let i = 0; i < data.length; i += 4) {
        const grayValue = 0.299 * data[i] + 0.587 * data[i + 1] + 0.114 * data[i + 2];
        colors.push(getGridColorForPercentage((grayValue / 255) * 100));
    }
    return colors;
};

/** Loads an image URL; handlers are attached before `src` so no load event is missed. */
const loadImage = (src: string): Promise<HTMLImageElement> =>
    new Promise((resolve, reject) => {
        const img = new Image();
        img.onload = () => resolve(img);
        img.onerror = reject;
        img.src = src;
    });

/** Reads a tensor's values as a flat number array (a [1, n] tensor yields its n values). */
const readVector = async (tensor: tf.Tensor): Promise<number[]> => Array.from(await tensor.data());

/**
 * Interactive digit-recognition visualiser.
 *
 * The user draws on a canvas. The drawing is downsampled to a 28x28 grid (shown as an
 * overlay) and can be classified with a remote tfjs layers model. The outputs of the
 * model's last three layers are rendered as bars. After a prediction, hovering a bar draws
 * connection lines to the layer below it (or, for the third bar row, to every grid cell).
 */
const NNVisRegion: React.FC<IntrinsicElementProps<NNVisRegionContent>> = ({
    elementData,
    isEditable,
    handleFocusItem,
    handleResize,
    handleDragStop,
    handleDelete,
}) => {
    const isPad = useMediaQuery({ query: '(min-width: 769px) and (max-width: 1280px)' });

    const [model, setModel] = useState<tf.LayersModel | null>(null);
    const [predictedClass, setPredictedClass] = useState<number | null>(null);
    const [lastLayerOutputs, setLastLayerOutputs] = useState<number[]>(() => zeros(LAST_LAYER_SIZE));
    const [secondLastLayerOutputs, setSecondLastLayerOutputs] = useState<number[]>(() => zeros(SECOND_LAST_LAYER_SIZE));
    const [thirdLastLayerOutputs, setThirdLastLayerOutputs] = useState<number[]>(() => zeros(THIRD_LAST_LAYER_SIZE));
    const [pixelColors, setPixelColors] = useState<string[]>(blankPixelColors);

    // Drawing canvas state.
    const [imageSrc, setImageSrc] = useState<string | null>(null);
    const canvasRef = useRef<HTMLCanvasElement | null>(null);
    const ctxRef = useRef<CanvasRenderingContext2D | null>(null);
    const drawingRef = useRef(false);
    // Offscreen GRID_SIZE x GRID_SIZE canvas, reused for every downsample instead of
    // allocating a new element on each pointer move.
    const scaledCanvasRef = useRef<HTMLCanvasElement | null>(null);

    // Connection-line overlay: the container, the input grid, and the lines canvas.
    const containerRef = useRef<HTMLDivElement>(null);
    const gridOverlayRef = useRef<HTMLDivElement>(null);
    const linesCanvasRef = useRef<HTMLCanvasElement>(null);
    const linesCtxRef = useRef<CanvasRenderingContext2D | null>(null);
    const activeConnectionsRef = useRef<{ layer: LayerName; index: number } | null>(null);

    // Size computation: shrink the drawing area on pad-sized screens.
    const scaleFactor = isPad ? 0.78 : 1;
    const canvasSize = BASE_CANVAS_SIZE * scaleFactor;

    // Keep the lines canvas bitmap the same size as its container.
    useEffect(() => {
        const updateCanvasSize = () => {
            const canvas = linesCanvasRef.current;
            const container = containerRef.current;
            if (canvas && container) {
                const rect = container.getBoundingClientRect();
                canvas.width = rect.width;
                canvas.height = rect.height;
                linesCtxRef.current = canvas.getContext('2d');
            }
        };

        updateCanvasSize();
        const resizeObserver = new ResizeObserver(updateCanvasSize);
        if (containerRef.current) {
            resizeObserver.observe(containerRef.current);
        }
        return () => resizeObserver.disconnect();
    }, [scaleFactor]);

    /** Clears the lines canvas and draws one straight line from `startPoint` to each end point. */
    const drawConnections = (startPoint: Point, endPoints: Point[]) => {
        const ctx = linesCtxRef.current;
        if (!ctx) return;

        ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
        ctx.beginPath();
        ctx.strokeStyle = 'rgba(0, 0, 0, 0.2)';
        ctx.lineWidth = 1;

        endPoints.forEach(end => {
            ctx.moveTo(startPoint.x, startPoint.y);
            ctx.lineTo(end.x, end.y);
        });

        ctx.stroke();
    };

    /** Returns the top-centre of each element matching `selector` inside the container, in container coordinates. */
    const topCentresOf = (container: HTMLElement, containerRect: DOMRect, selector: string): Point[] =>
        Array.from(container.querySelectorAll(selector), node => {
            const rect = node.getBoundingClientRect();
            return {
                x: rect.left + rect.width / 2 - containerRect.left,
                y: rect.top - containerRect.top,
            };
        });

    /**
     * Draws lines from the hovered node's bottom-centre to the layer below it:
     * 'first' -> second-row nodes, 'second' -> third-row nodes, 'third' -> every grid-cell centre.
     * Does nothing until a prediction exists.
     */
    const handleNodeHover = (event: React.MouseEvent<HTMLDivElement>, layer: LayerName, index: number) => {
        // Compare against null explicitly: class 0 is a valid prediction.
        if (predictedClass === null) return;

        const container = containerRef.current;
        const grid = gridOverlayRef.current;
        if (!container || !grid) return;

        const nodeRect = event.currentTarget.getBoundingClientRect();
        const containerRect = container.getBoundingClientRect();

        // Start point: bottom-centre of the hovered node.
        const startX = nodeRect.left + nodeRect.width / 2 - containerRect.left;
        const startY = nodeRect.bottom - containerRect.top;

        let endPoints: Point[] = [];

        if (layer === 'third') {
            // Connect to the centre of every grid cell.
            const gridRect = grid.getBoundingClientRect();
            const cellWidth = gridRect.width / GRID_SIZE;
            const cellHeight = gridRect.height / GRID_SIZE;

            endPoints = Array.from({ length: GRID_SIZE * GRID_SIZE }, (_, i) => {
                const row = Math.floor(i / GRID_SIZE);
                const col = i % GRID_SIZE;
                return {
                    x: gridRect.left + col * cellWidth + cellWidth / 2 - containerRect.left,
                    y: gridRect.top + row * cellHeight + cellHeight / 2 - containerRect.top,
                };
            });
        } else if (layer === 'second') {
            // Query within this instance only, so other NNVisRegion elements on the page are ignored.
            endPoints = topCentresOf(container, containerRect, `.${styles.output3} .${styles.output3Item}`);
        } else {
            endPoints = topCentresOf(container, containerRect, `.${styles.output2} .${styles.output2Item}`);
        }

        activeConnectionsRef.current = { layer, index };
        drawConnections({ x: startX, y: startY }, endPoints);
    };

    /** Removes all connection lines. */
    const handleNodeLeave = () => {
        activeConnectionsRef.current = null;
        const ctx = linesCtxRef.current;
        ctx?.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
    };

    // Load the model once; ignore the result if the component unmounted in the meantime.
    useEffect(() => {
        let cancelled = false;
        tf.loadLayersModel(MODEL_URL)
            .then(loadedModel => {
                if (cancelled) return;
                setModel(loadedModel);
                console.log('模型加载成功！');
            })
            .catch(err => console.error('加载模型时发生错误：', err));
        return () => {
            cancelled = true;
        };
    }, []);

    /**
     * Classifies the current 28x28 downsampled drawing (`imageSrc`). Stores the arg-max class and
     * the flattened outputs of the model's last three layers. Every tensor created is disposed.
     */
    const predict = async () => {
        if (!imageSrc || !model) return;

        const tailLayers = model.layers.slice(-VISUALISED_LAYER_COUNT);
        if (tailLayers.length < VISUALISED_LAYER_COUNT) {
            console.error('识别时发生错误：', `model has fewer than ${VISUALISED_LAYER_COUNT} layers`);
            return;
        }

        const tensors: tf.Tensor[] = [];
        try {
            const img = await loadImage(imageSrc);

            // tidy() frees the intermediate pixel/gray/resized tensors; only the input survives.
            const inputTensor = tf.tidy(() => {
                const pixels = tf.browser.fromPixels(img);
                // Average the RGB channels into a single grayscale channel.
                const gray = tf.mean<tf.Tensor3D>(pixels, -1, true);
                const resized = tf.image.resizeBilinear(gray, [GRID_SIZE, GRID_SIZE]);
                // Normalise to [0, 1] and reshape to the model's [batch, 28, 28] input.
                return resized.div(255).reshape([1, GRID_SIZE, GRID_SIZE]);
            });
            tensors.push(inputTensor);

            const prediction = model.predict(inputTensor) as tf.Tensor;
            tensors.push(prediction);

            // Probe model exposing the outputs of the last three layers (third-last first).
            // NOTE(review): assumes each of these layers has a single output.
            const probeModel = tf.model({
                inputs: model.inputs,
                outputs: tailLayers.map(layer => layer.output as tf.SymbolicTensor),
            });
            const layerOutputs = probeModel.predict(inputTensor) as tf.Tensor[];
            tensors.push(...layerOutputs);

            const scores = await readVector(prediction);
            const [thirdValues = [], secondValues = [], lastValues = []] = await Promise.all(
                layerOutputs.map(readVector),
            );

            setPredictedClass(scores.indexOf(Math.max(...scores)));
            setLastLayerOutputs(lastValues);
            setSecondLastLayerOutputs(secondValues);
            setThirdLastLayerOutputs(thirdValues);
        } catch (err) {
            console.error('识别时发生错误：', err);
        } finally {
            tensors.forEach(tensor => tensor.dispose());
        }
    };

    /** Resets the prediction, all layer bars and the grid overlay to their initial blank state. */
    const resetData = () => {
        setPredictedClass(null);
        setLastLayerOutputs(zeros(LAST_LAYER_SIZE));
        setSecondLastLayerOutputs(zeros(SECOND_LAST_LAYER_SIZE));
        setThirdLastLayerOutputs(zeros(THIRD_LAST_LAYER_SIZE));
        setPixelColors(blankPixelColors());
    };

    const lastLayerPercentages = toPercentages(lastLayerOutputs);
    const secondLastLayerPercentages = toPercentages(secondLastLayerOutputs);
    const thirdLastLayerPercentages = toPercentages(thirdLastLayerOutputs);

    // Fill the drawing canvas white. Changing its width/height attributes (when canvasSize
    // changes) clears the bitmap, so this is re-applied on every size change.
    useEffect(() => {
        const canvas = canvasRef.current;
        if (!canvas) return;
        const ctx = canvas.getContext('2d', { willReadFrequently: true });
        if (!ctx) return;
        ctxRef.current = ctx;
        ctx.fillStyle = 'white';
        ctx.fillRect(0, 0, canvas.width, canvas.height);
    }, [canvasSize]);

    const startDrawing = () => {
        drawingRef.current = true;
        ctxRef.current?.beginPath();
    };

    const endDrawing = () => {
        drawingRef.current = false;
        ctxRef.current?.beginPath();
    };

    /**
     * Downsamples the drawing canvas to GRID_SIZE x GRID_SIZE and inverts it (white ink on
     * black). Stores the result as a data URL for `predict`, and refreshes the grid colours.
     */
    const updateGridColors = () => {
        const canvas = canvasRef.current;
        if (!canvas || !ctxRef.current) return;

        const scaledCanvas = scaledCanvasRef.current ?? document.createElement('canvas');
        if (!scaledCanvasRef.current) {
            scaledCanvas.width = GRID_SIZE;
            scaledCanvas.height = GRID_SIZE;
            scaledCanvasRef.current = scaledCanvas;
        }
        const scaledContext = scaledCanvas.getContext('2d', { willReadFrequently: true });
        if (!scaledContext) return;

        // Start from a blank bitmap so the result matches a freshly created canvas.
        scaledContext.clearRect(0, 0, GRID_SIZE, GRID_SIZE);
        scaledContext.drawImage(
            canvas,
            0, 0,
            canvas.width, canvas.height,
            0, 0,
            GRID_SIZE, GRID_SIZE
        );
        const scaledImageData = scaledContext.getImageData(0, 0, GRID_SIZE, GRID_SIZE);

        // Invert RGB (alpha untouched).
        const data = scaledImageData.data;
        for (let i = 0; i < data.length; i += 4) {
            data[i] = 255 - data[i];
            data[i + 1] = 255 - data[i + 1];
            data[i + 2] = 255 - data[i + 2];
        }
        scaledContext.putImageData(scaledImageData, 0, 0);

        setImageSrc(scaledCanvas.toDataURL());
        setPixelColors(toPixelColors(data));
    };

    /**
     * Extends the current stroke to a viewport point (clientX/clientY). Viewport coordinates
     * are mapped to canvas pixels; a point outside the canvas ends the stroke instead.
     */
    const strokeTo = (clientX: number, clientY: number) => {
        const ctx = ctxRef.current;
        const canvas = canvasRef.current;
        if (!drawingRef.current || !ctx || !canvas) return;

        ctx.lineWidth = PEN_WIDTH;
        ctx.lineCap = 'round';
        ctx.strokeStyle = 'black';

        const rect = canvas.getBoundingClientRect();
        const x = (clientX - rect.left) * (canvas.width / rect.width);
        const y = (clientY - rect.top) * (canvas.height / rect.height);
        if (x < 0 || x > canvas.width || y < 0 || y > canvas.height) {
            endDrawing();
            return;
        }

        ctx.lineTo(x, y);
        ctx.stroke();
        // Restart the path at this point so each segment is stroked only once.
        ctx.beginPath();
        ctx.moveTo(x, y);
        // Refresh the grid live while drawing.
        updateGridColors();
    };

    const draw = (event: React.MouseEvent<HTMLDivElement>) => strokeTo(event.clientX, event.clientY);

    const drawTouch = (event: React.TouchEvent<HTMLDivElement>) => {
        const touch = event.touches[0];
        if (touch) strokeTo(touch.clientX, touch.clientY);
    };

    /** Wipes the drawing canvas to white and resets all derived state. */
    const clearCanvas = () => {
        const canvas = canvasRef.current;
        const ctx = ctxRef.current;
        if (canvas && ctx) {
            ctx.fillStyle = 'white';
            ctx.fillRect(0, 0, canvas.width, canvas.height);
            setImageSrc(null);
        }
        resetData();
    };

    const handleMouseUp = () => {
        endDrawing();
        updateGridColors();
    };

    // Finish an in-progress stroke when the pointer leaves the drawing area. Otherwise
    // re-entering with the button released would keep drawing.
    const handleMouseLeave = () => {
        if (drawingRef.current) handleMouseUp();
    };

    /** Renders one activation bar filled to `percentage` (clamped to 100% by maxHeight). */
    const renderBar = (percentage: number) => (
        <div
            style={{
                height: `${percentage}%`,
                backgroundColor: getColorForPercentage(percentage),
                maxHeight: '100%'
            }}
            className={styles.outputColorArea}
        />
    );

    return (
        <BaseDragableElement
            elementData={elementData}
            isEditable={isEditable}
            handleFocusItem={handleFocusItem}
            handleResize={handleResize}
            handleDragStop={handleDragStop}
            handleDelete={handleDelete}
        >
            <div
                ref={containerRef}
                style={{ ...commonStyle, position: 'relative' }}
                onClick={e => { if (isEditable) handleFocusItem(elementData, e); }}
                className={`${elementData.isFocus && isEditable ? styles.elementFocused : ''} ${isEditable ? styles.element : ''}`}
            >
                <canvas
                    ref={linesCanvasRef}
                    style={{
                        position: 'absolute',
                        top: 0,
                        left: 0,
                        pointerEvents: 'none',
                        zIndex: 999
                    }}
                />
                <div className={styles.nnVisBox}>
                    <div className={styles.predictionBox}>
                        <div className={styles.outputBox}>
                            {/* Row 1: last layer (one bar per class, labelled with its index) */}
                            <div className={styles.output1}>
                                {lastLayerPercentages.map((percentage, index) => (
                                    <div
                                        key={index}
                                        className={styles.output1ItemBox}
                                    >
                                        <div className={styles.output1ItemIndex}>
                                            {index}
                                        </div>
                                        <div
                                            className={styles.output1Item}
                                            onMouseEnter={(e) => handleNodeHover(e, 'first', index)}
                                            onMouseLeave={handleNodeLeave}
                                        >
                                            {renderBar(percentage)}
                                        </div>
                                    </div>
                                ))}
                            </div>
                            {/* Row 2: second-to-last layer */}
                            <div className={styles.output2}>
                                {secondLastLayerPercentages.map((percentage, index) => (
                                    <div
                                        key={index}
                                        className={styles.output2Item}
                                        onMouseEnter={(e) => handleNodeHover(e, 'second', index)}
                                        onMouseLeave={handleNodeLeave}
                                    >
                                        {renderBar(percentage)}
                                    </div>
                                ))}
                            </div>
                            {/* Row 3: third-to-last layer */}
                            <div className={styles.output3}>
                                {thirdLastLayerPercentages.map((percentage, index) => (
                                    <div
                                        key={index}
                                        className={styles.output3Item}
                                        onMouseEnter={(e) => handleNodeHover(e, 'third', index)}
                                        onMouseLeave={handleNodeLeave}
                                    >
                                        {renderBar(percentage)}
                                    </div>
                                ))}
                            </div>
                        </div>
                        {/* Drawing area */}
                        <div className={styles.canvasBox}>
                            <div className={styles.canvasContainer}>
                                <canvas
                                    id="canvas"
                                    width={canvasSize}
                                    height={canvasSize}
                                    ref={canvasRef}
                                    className={styles.canvas}
                                    style={{
                                        width: canvasSize,
                                        height: canvasSize
                                    }}
                                />
                                <div
                                    className={styles.gridOverlay}
                                    ref={gridOverlayRef}
                                    style={{
                                        width: canvasSize,
                                        height: canvasSize,
                                        gridTemplateColumns: `repeat(${GRID_SIZE}, ${canvasSize / GRID_SIZE}px)`,
                                        touchAction: 'none' // prevent scrolling/zooming while drawing
                                    }}
                                    onMouseDown={startDrawing}
                                    onTouchStart={startDrawing}
                                    onMouseUp={handleMouseUp}
                                    onTouchEnd={handleMouseUp}
                                    onMouseLeave={handleMouseLeave}
                                    onMouseMove={draw}
                                    onTouchMove={drawTouch}
                                >
                                    {/* GRID_SIZE x GRID_SIZE cells mirroring the downsampled drawing */}
                                    {Array.from({ length: GRID_SIZE * GRID_SIZE }, (_, index) => (
                                        <div
                                            key={index}
                                            className={styles.gridCell}
                                            style={{ backgroundColor: pixelColors[index] }}
                                        />
                                    ))}
                                </div>
                            </div>
                            <div className={styles.buttonContainer}>
                                <button onClick={predict} className={styles.button}>
                                    识别
                                </button>
                                <button onClick={clearCanvas} className={styles.button}>
                                    清空
                                </button>
                            </div>
                        </div>
                    </div>
                </div>
            </div>
        </BaseDragableElement>
    );
};

export default NNVisRegion;