import CloseIcon from "../assets/close.svg";
import { useRef, useEffect, useState } from "react";
import { PiPaintBrushLight, PiEraserLight, PiSquare, PiCircle, PiCircleFill, PiDotOutlineFill, PiArrowsInSimpleLight } from "react-icons/pi";
import { SlActionUndo } from "react-icons/sl";
import { SlActionRedo } from "react-icons/sl";
import { RxReset } from "react-icons/rx";
import { CgDropInvert } from "react-icons/cg";
import { Tooltip } from "flowbite-react";
import axios from "axios";
import ToastErrorMessage from "./ToastErrorMessage";
import Spinner from "./Spinner";
import { useSelector } from "react-redux";
import { selectAccessToken, selectAuthenticated } from  "../selectors";

export default function MaskEditor({image, setShowMaskEditor, setBinaryMaskImage}) {
    // Canvas handles: canvasRef holds the <canvas> element; ctxRef caches its 2D context
    // once one of the effects below obtains it.
    const canvasRef = useRef(null);
    const ctxRef = useRef(null);
    // Interaction flags kept in refs so the mouse handlers can flip them without a re-render.
    // isDrawingEnabled: a tool has been activated; isDrawing: a stroke/rectangle is in progress.
    const isDrawingEnabled = useRef(false);
    const isDrawing = useRef(false);
    // Paint colours as CSS colour strings: brush stroke, eraser marker (see drawPoint), fill.
    const [drawColor, setDrawColor] = useState('rgba(255, 255, 255, 1)');
    const [eraserColor, setEraserColor] = useState('rgba(225,29,72, 0.5)');
    const [fillStyleColor, setFillStyleColor] = useState('rgba(255, 255, 255, 1)');
    // Rectangle tool: anchor point of the rectangle being drawn, and the current rectangle as
    // {x, y, width, height} in canvas pixels (width/height are negative when dragged up/left).
    const [isRectangleMode, setIsRectangleMode] = useState(false);
    const [startPoint, setStartPoint] = useState(null);
    const [currentRect, setCurrentRect] = useState(null);

    // NOTE(review): within this section these boolean mode flags are only ever reset; the
    // handlers branch on currentMode instead. Check the JSX before relying on or removing them.
    const [isDrawingMode, setIsDrawingMode] = useState(false);
    const [isPointMode, setIsPointMode] = useState(false);
    const [isEraserMode, setIsEraserMode] = useState(false);
    // Undo/redo history; canvasStates holds canvas.toDataURL() snapshots (see saveCanvasState).
    const [canvasStates, setCanvasStates] = useState([]);
    const [redoStates, setRedoStates] = useState([]);
    // Click points recorded by handleCanvasClick as {x, y, type: 'click'}.
    const [drawingPoints, setDrawingPoints] = useState([]);
    const [lineWidth, setLineWidth] = useState(12);
    // Image being edited; starts as the `image` prop and is reset to it by resetAll.
    const [currentImage, setCurrentImage] = useState(image);
    // Stroke samples in canvas pixels. A null entry marks the end of one stroke; rectangle
    // edge samples are added with isRectanglePoint: true.
    const [brushPoints, setBrushPoints] = useState([]);
    const [eraserPoints, setEraserPoints] = useState([]);
    const [prompt, setPrompt] = useState('');
    const [errorMessage, setErrorMessage] = useState('');
    const [showErrorMessage, setShowErrorMessage] = useState(false);

    // Mask-generation state; presumably consumed by code later in this file (not visible here).
    const [isGeneratingMask, setIsGeneratingMask] = useState(false);
    const [isMaskGenerated, setIsMaskGenerated] = useState(false);
    const accessToken = useSelector(selectAccessToken) || null;
    const authenticated = useSelector(selectAuthenticated);
    const [generatedMask, setGeneratedMask] = useState(null);

    // Natural size in pixels of the loaded image (set by the image-load effect).
    const [imageDimensions, setImageDimensions] = useState({ width: 0, height: 0 });

    const [redoPoints, setRedoPoints] = useState([]);
    // Rectangle editing: active resize handle, drag flag, and the pointer's offset from the
    // rectangle origin captured when a drag starts.
    const [isResizing, setIsResizing] = useState(false);
    const [resizeHandle, setResizeHandle] = useState(null);
    const [isDragging, setIsDragging] = useState(false);
    const [dragOffset, setDragOffset] = useState({ x: 0, y: 0 });
    const HANDLE_SIZE = 8; // Size of resize handles in pixels

    // Canvas interaction modes; the mouse handlers branch on currentMode.
    // NOTE(review): this object is recreated on every render; it could live at module scope.
    const DrawingMode = {
        NONE: 'none',
        STROKE: 'stroke',
        POINT: 'point',
        ERASER: 'eraser',
        RECTANGLE: 'rectangle'
    };

    const [currentMode, setCurrentMode] = useState(DrawingMode.NONE);

    // Polarity of a point in POINT mode: drawn green (POSITIVE) or red (NEGATIVE) by drawPoint.
    const PointSubMode = {
        POSITIVE: 'positive',
        NEGATIVE: 'negative'
    };

    // Selected point polarity (null until one is chosen) and visibility of its dropdown menu.
    const [pointSubMode, setPointSubMode] = useState(null);
    const [showPointSubmodeMenu, setShowPointSubmodeMenu] = useState(false);

    // Dropdown container; the click-outside effect closes the sub-mode menu on clicks outside it.
    const dropdownRef = useRef(null);

    // Brush-width scaling: BASE_LINE_WIDTH applies to an image REFERENCE_WIDTH pixels wide and
    // is scaled proportionally for other widths (see getScaledLineWidth).
    const BASE_LINE_WIDTH = 12; // Brush width in canvas pixels at the reference image width
    const REFERENCE_WIDTH = 800; // Image width in pixels at which BASE_LINE_WIDTH is used unscaled
    /**
     * Brush width scaled to the image size, so strokes look proportionate on small and
     * large images.
     *
     * @param {number} [width=imageDimensions.width] - Image width in pixels to scale against.
     *   Defaults to the stored image width. Pass an explicit width when that state may not
     *   have caught up yet, e.g. inside an image onload handler that just called
     *   setImageDimensions.
     * @returns {number} BASE_LINE_WIDTH when no width is known; otherwise
     *   BASE_LINE_WIDTH * width / REFERENCE_WIDTH, clamped to [6, 24].
     */
    const getScaledLineWidth = (width = imageDimensions.width) => {
        const MIN_LINE_WIDTH = 6;
        const MAX_LINE_WIDTH = 24;
        if (!width) return BASE_LINE_WIDTH;
        const scaleFactor = width / REFERENCE_WIDTH;
        const scaledWidth = BASE_LINE_WIDTH * scaleFactor;
        return Math.max(MIN_LINE_WIDTH, Math.min(scaledWidth, MAX_LINE_WIDTH));
    };

    // Add this component at the top with other imports
    const PointIcon = ({ isPositive }) => (
        <div className="relative flex items-center justify-center w-5 h-5">
            <div className={`absolute w-4 h-4 rounded-full ${isPositive ? 'bg-green-500' : 'bg-red-500'} opacity-80`}></div>
            <div className={`absolute w-2 h-2 rounded-full ${isPositive ? 'bg-green-300' : 'bg-red-300'}`}></div>
        </div>
    );

    // Points clicked in POINT mode as {x, y} in canvas pixels, split by polarity (PointSubMode).
    const [positivePoints, setPositivePoints] = useState([]);
    const [negativePoints, setNegativePoints] = useState([]);

    // Mask inversion toggle; where it is read is outside this section.
    const [isMaskInverted, setIsMaskInverted] = useState(false);

    // Whenever the source image changes: size the canvas backing store to its laid-out CSS box,
    // cache the 2D context in ctxRef, and apply the brush settings. The image-load effect later
    // resizes the canvas again to the image's natural size once the image has loaded.
    // NOTE: only currentImage is a dependency, so later lineWidth/colour changes are not
    // re-applied here (the tool-enable functions and mouse handlers set them instead).
    useEffect(() => {
        const canvas = canvasRef.current;
        if (!canvas) return;

        canvas.width = canvas.offsetWidth;
        canvas.height = canvas.offsetHeight;

        const ctx = canvas.getContext('2d');
        ctx.lineWidth = lineWidth;
        ctx.strokeStyle = drawColor;
        ctx.fillStyle = fillStyleColor;
        ctx.lineCap = 'round';
        ctxRef.current = ctx;
    }, [currentImage]);

    // For the lifetime of the editor: lock page scrolling (body gets 'overflow-hidden'), and
    // keep the canvas backing store sized to its on-screen box, carrying the existing drawing
    // (rescaled) across every resize reported by a ResizeObserver.
    useEffect(() => {
        document.body.classList.add('overflow-hidden');

        let resizeObserver;
        let animationFrameId;
        let lastWidth = 0;
        let lastHeight = 0;

        const handleResize = () => {
            // Coalesce bursts of observer notifications into a single update per frame.
            if (animationFrameId) {
                cancelAnimationFrame(animationFrameId);
            }

            animationFrameId = requestAnimationFrame(() => {
                const canvas = canvasRef.current;
                const ctx = ctxRef.current;
                if (!canvas || !ctx) return;

                const newWidth = Math.max(canvas.offsetWidth, 1); // Ensure minimum width of 1
                const newHeight = Math.max(canvas.offsetHeight, 1); // Ensure minimum height of 1
                if (newWidth === lastWidth && newHeight === lastHeight) return;

                // Assigning canvas.width/height wipes the pixels, so copy them out first.
                const snapshot = document.createElement('canvas');
                snapshot.width = Math.max(canvas.width, 1);
                snapshot.height = Math.max(canvas.height, 1);
                if (canvas.width > 0 && canvas.height > 0) {
                    snapshot.getContext('2d').drawImage(canvas, 0, 0);
                }

                // Resizing also resets every context property to its default. Capture the live
                // settings now: values captured by this mount-time closure go stale once the
                // brush width is rescaled or another tool (e.g. the eraser) is selected.
                const {
                    lineWidth: liveLineWidth,
                    strokeStyle: liveStrokeStyle,
                    fillStyle: liveFillStyle,
                    globalCompositeOperation: liveCompositeOperation,
                } = ctx;

                canvas.width = newWidth;
                canvas.height = newHeight;

                // Copy the pixels back while compositing is still the default 'source-over';
                // under the eraser's 'destination-out' this draw would erase instead.
                ctx.drawImage(snapshot, 0, 0, newWidth, newHeight);

                ctx.lineWidth = liveLineWidth;
                ctx.strokeStyle = liveStrokeStyle;
                ctx.fillStyle = liveFillStyle;
                ctx.globalCompositeOperation = liveCompositeOperation;
                ctx.lineCap = 'round';

                lastWidth = newWidth;
                lastHeight = newHeight;
            });
        };

        if (canvasRef.current) {
            resizeObserver = new ResizeObserver(handleResize);
            resizeObserver.observe(canvasRef.current);

            // Initial size setup
            const initialWidth = Math.max(canvasRef.current.offsetWidth, 1);
            const initialHeight = Math.max(canvasRef.current.offsetHeight, 1);
            canvasRef.current.width = initialWidth;
            canvasRef.current.height = initialHeight;
            lastWidth = initialWidth;
            lastHeight = initialHeight;
        }

        return () => {
            document.body.classList.remove('overflow-hidden');
            if (resizeObserver) {
                resizeObserver.disconnect();
            }
            if (animationFrameId) {
                cancelAnimationFrame(animationFrameId);
            }
        };
    }, []);

    // On mount, start from a clean slate; resetAll(true) also calls setBinaryMaskImage('').
    useEffect(() => {
        const clearBinaryMask = true;
        resetAll(clearBinaryMask);
    }, []);

    // Load currentImage to learn its natural size, then size the canvas backing store to match
    // and re-apply brush settings (resizing a canvas resets all of its context state).
    useEffect(() => {
        let cancelled = false;
        const img = new Image();

        img.onload = () => {
            const canvas = canvasRef.current;
            // Ignore a load that finishes after the image changed again or the editor unmounted.
            if (cancelled || !canvas) return;

            canvas.width = img.width;
            canvas.height = img.height;
            setImageDimensions({
                width: img.width,
                height: img.height
            });

            // Same formula as getScaledLineWidth, but fed the freshly loaded width: the
            // imageDimensions captured by this closure is still the previous value, because
            // setImageDimensions above only takes effect on the next render.
            const scaledWidth = img.width
                ? Math.max(6, Math.min(BASE_LINE_WIDTH * (img.width / REFERENCE_WIDTH), 24))
                : BASE_LINE_WIDTH;
            setLineWidth(scaledWidth);

            const ctx = canvas.getContext('2d');
            ctx.lineWidth = scaledWidth;
            ctx.strokeStyle = drawColor;
            ctx.fillStyle = fillStyleColor;
            ctx.lineCap = 'round';
        };
        // Register the handler before starting the load.
        img.src = currentImage;

        return () => {
            cancelled = true;
        };
    }, [currentImage]);

    // While the point sub-mode dropdown is open, close it on any mousedown outside dropdownRef.
    // The document listener exists only while the menu is shown.
    useEffect(() => {
        if (!showPointSubmodeMenu) return undefined;

        const closeIfOutside = ({ target }) => {
            const container = dropdownRef.current;
            if (container && !container.contains(target)) {
                setShowPointSubmodeMenu(false);
            }
        };

        document.addEventListener('mousedown', closeIfOutside);
        return () => document.removeEventListener('mousedown', closeIfOutside);
    }, [showPointSubmodeMenu]);

    /**
     * Return the editor to its initial state. Clears:
     * - recorded points and undo/redo history;
     * - the prompt and error banner;
     * - every mode, rectangle and drag flag, plus the mask flags (including isMaskInverted).
     * It also reloads `image` as the current image and wipes the canvas, restoring the default
     * brush settings. Finally it pushes the cleared canvas as the first undo snapshot.
     *
     * @param {boolean} [resetBinaryMaskImage=true] - When true, also calls setBinaryMaskImage('').
     */
    const resetAll = (resetBinaryMaskImage = true) => {
        // Recorded geometry and history
        setBrushPoints([]);
        setEraserPoints([]);
        setDrawingPoints([]);
        setPositivePoints([]);
        setNegativePoints([]);
        setCanvasStates([]);
        setRedoStates([]);
        setRedoPoints([]);

        // Prompt and error banner
        setPrompt('');
        setErrorMessage('');
        setShowErrorMessage(false);

        setCurrentImage(image);

        // Interaction refs and tool/mode flags
        isDrawingEnabled.current = false;
        isDrawing.current = false;
        setIsDrawingMode(false);
        setIsPointMode(false);
        setIsEraserMode(false);
        setIsRectangleMode(false);
        setCurrentMode(DrawingMode.NONE);

        // Mask flags. The inversion toggle is cleared as well (as cleanupOnClose does), so a
        // reset restores every mask flag to its initial value.
        setIsMaskGenerated(false);
        setIsGeneratingMask(false);
        setIsMaskInverted(false);
        if (resetBinaryMaskImage) {
            setBinaryMaskImage('');
        }
        // NOTE(review): generatedMask is not reset here; confirm whether it should be.

        // Rectangle, drag/resize and point sub-mode UI
        setCurrentRect(null);
        setStartPoint(null);
        setIsResizing(false);
        setResizeHandle(null);
        setIsDragging(false);
        setDragOffset({ x: 0, y: 0 });
        setShowPointSubmodeMenu(false);
        setPointSubMode(null);

        // Wipe the canvas and restore default drawing settings.
        const canvas = canvasRef.current;
        const ctx = ctxRef.current;
        if (canvas && ctx) {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            ctx.globalCompositeOperation = 'source-over';
            ctx.strokeStyle = drawColor;
            ctx.fillStyle = fillStyleColor;
            const scaledWidth = getScaledLineWidth();
            setLineWidth(scaledWidth);
            ctx.lineWidth = scaledWidth;
            ctx.lineCap = 'round';
            canvas.style.cursor = 'default';
        }

        // Seed the undo history with the cleared canvas. State updates are queued, so this
        // snapshot lands after the setCanvasStates([]) above.
        saveCanvasState();
    };

    /**
     * Teardown used when the editor closes. Runs resetAll, which already restores almost
     * every piece of state, then applies the few differences a close needs:
     * - clears the mask-inversion toggle;
     * - empties the undo history (resetAll leaves one snapshot of the cleared canvas in it);
     * - clears the canvas through a fresh context lookup, which covers the case where
     *   ctxRef is not populated and resetAll therefore skipped its own clear.
     *
     * @param {boolean} [resetBinaryMaskImage=true] - Forwarded to resetAll; when true,
     *   setBinaryMaskImage('') is called.
     */
    const cleanupOnClose = (resetBinaryMaskImage = true) => {
        resetAll(resetBinaryMaskImage);

        const canvas = canvasRef.current;
        if (canvas) {
            canvas.getContext('2d').clearRect(0, 0, canvas.width, canvas.height);
        }

        setIsMaskInverted(false);
        // Queued after resetAll's snapshot push, so the history ends up empty.
        setCanvasStates([]);
    };

    // Activate the freehand brush (STROKE mode). If a 2D context exists, prime it for
    // ordinary painting with the brush colour, fill colour, image-scaled width and round caps.
    const enableBrushDrawing = () => {
        isDrawingEnabled.current = true;
        setCurrentMode(DrawingMode.STROKE);

        const ctx = ctxRef.current;
        if (!ctx) return;

        ctx.globalCompositeOperation = 'source-over';
        ctx.strokeStyle = drawColor;
        ctx.fillStyle = fillStyleColor;
        ctx.lineWidth = getScaledLineWidth();
        ctx.lineCap = 'round';
    };

    /**
     * Paint a point marker at canvas coordinates (x, y): an outer disc of radius 4 CSS pixels
     * and an inner disc of radius 2 CSS pixels. Radii are converted to canvas pixels using the
     * horizontal display scale, so the marker keeps a constant on-screen size however the
     * canvas is stretched.
     *
     * @param {CanvasRenderingContext2D} ctx - Context to draw into; its fillStyle is left changed.
     * @param {number} x - X in canvas pixel coordinates.
     * @param {number} y - Y in canvas pixel coordinates.
     * @param {boolean} isPositive - Green marker when true, red marker when false.
     */
    const drawPointWithType = (ctx, x, y, isPositive) => {
        const canvas = canvasRef.current;
        // Canvas pixels per displayed CSS pixel along x.
        const scale = canvas.width / canvas.getBoundingClientRect().width;

        const OUTER_RADIUS_CSS = 4;
        const INNER_RADIUS_CSS = 2;
        const [outerColor, innerColor] = isPositive
            ? ['rgba(0, 255, 0, 0.6)', 'rgba(144, 238, 144, 0.8)']   // green ring, light-green dot
            : ['rgba(255, 0, 0, 0.6)', 'rgba(255, 182, 193, 0.8)'];  // red ring, pink dot

        const layers = [
            [OUTER_RADIUS_CSS * scale, outerColor],
            [INNER_RADIUS_CSS * scale, innerColor],
        ];
        for (const [radius, color] of layers) {
            ctx.beginPath();
            ctx.fillStyle = color;
            ctx.arc(x, y, radius, 0, Math.PI * 2);
            ctx.fill();
        }
    };

    // Switch to POINT mode and prime the context for ordinary painting. Existing strokes are
    // deliberately left on the canvas here; the first point click (handleMouseDown) clears them.
    const enablePointDrawing = () => {
        isDrawingEnabled.current = true;
        setCurrentMode(DrawingMode.POINT);

        const ctx = ctxRef.current;
        if (!ctx) return;

        Object.assign(ctx, {
            globalCompositeOperation: 'source-over',
            strokeStyle: drawColor,
            fillStyle: fillStyleColor,
            lineWidth: getScaledLineWidth(),
            lineCap: 'round',
        });
    };

    // Switch to ERASER mode. With 'destination-out' compositing, strokes remove existing pixels;
    // the opaque black strokeStyle only supplies coverage and never appears on the canvas.
    // fillStyle is intentionally left as it was.
    const enableEraser = () => {
        isDrawingEnabled.current = true;
        setCurrentMode(DrawingMode.ERASER);

        const ctx = ctxRef.current;
        if (!ctx) return;

        Object.assign(ctx, {
            globalCompositeOperation: 'destination-out',
            strokeStyle: 'rgba(0, 0, 0, 1)',
            lineWidth: getScaledLineWidth(),
            lineCap: 'round',
        });
    };

    // Append a data-URL snapshot of the canvas (canvas.toDataURL()) to the undo history.
    // Any new snapshot invalidates redo, so both redo stacks are emptied. No-op without a canvas.
    const saveCanvasState = () => {
        const canvas = canvasRef.current;
        if (!canvas) return;

        const snapshot = canvas.toDataURL();
        setCanvasStates((history) => history.concat([snapshot]));
        setRedoStates([]);
        setRedoPoints([]);
    };

    /**
     * Paint the marker for a click at canvas coordinates (x, y), sized in CSS pixels and
     * converted to canvas pixels via the horizontal display scale.
     *
     * - ERASER mode: a single eraserColor disc of radius 4 CSS pixels.
     * - Otherwise: an outer disc (radius 4) and inner disc (radius 2), green for
     *   PointSubMode.POSITIVE and red for PointSubMode.NEGATIVE.
     * - No sub-mode selected: nothing is drawn. The mouse handlers record no point in that
     *   case, so a painted marker would be a phantom dot captured in the undo snapshot.
     *
     * @param {CanvasRenderingContext2D} ctx - Context to draw into; its fillStyle is left changed.
     * @param {number} x - X in canvas pixel coordinates.
     * @param {number} y - Y in canvas pixel coordinates.
     */
    const drawPoint = (ctx, x, y) => {
        const canvas = canvasRef.current;
        const rect = canvas.getBoundingClientRect();
        // Canvas pixels per displayed CSS pixel along x; keeps markers a fixed on-screen size.
        const scale = canvas.width / rect.width;
        const outerRadius = 4 * scale;
        const innerRadius = 2 * scale;

        const fillCircle = (radius, color) => {
            ctx.beginPath();
            ctx.fillStyle = color;
            ctx.arc(x, y, radius, 0, Math.PI * 2);
            ctx.fill();
        };

        if (currentMode === DrawingMode.ERASER) {
            fillCircle(outerRadius, eraserColor);
            return;
        }

        let outerColor;
        let innerColor;
        if (pointSubMode === PointSubMode.POSITIVE) {
            outerColor = 'rgba(0, 255, 0, 0.6)';      // semi-transparent green
            innerColor = 'rgba(144, 238, 144, 0.8)';  // light green
        } else if (pointSubMode === PointSubMode.NEGATIVE) {
            outerColor = 'rgba(255, 0, 0, 0.6)';      // semi-transparent red
            innerColor = 'rgba(255, 182, 193, 0.8)';  // light red
        } else {
            return;
        }

        fillCircle(outerRadius, outerColor);
        fillCircle(innerRadius, innerColor);
    };

    /**
     * Whether canvas point (x, y) lies inside rect, edges inclusive.
     *
     * rect.width / rect.height may be negative: the rubber-band rectangle stores the signed
     * drag delta, so a rectangle drawn up or to the left has its origin at its right/bottom
     * edge. The bounds are normalised first so such rectangles can still be grabbed.
     *
     * @param {number} x - X in canvas pixels.
     * @param {number} y - Y in canvas pixels.
     * @param {{x: number, y: number, width: number, height: number}} rect
     * @returns {boolean}
     */
    const isPointInRectangle = (x, y, rect) => {
        const left = Math.min(rect.x, rect.x + rect.width);
        const right = Math.max(rect.x, rect.x + rect.width);
        const top = Math.min(rect.y, rect.y + rect.height);
        const bottom = Math.max(rect.y, rect.y + rect.height);
        return x >= left && x <= right && y >= top && y <= bottom;
    };

    /**
     * Mouse-down on the canvas. Does nothing until a tool has set isDrawingEnabled.
     * The event is converted to canvas pixel coordinates, then:
     * - RECTANGLE with an existing rect: delete it if the delete button was hit; otherwise
     *   start a resize (on a handle) or a drag (inside the rect).
     * - Any other mode while a rect exists: ignored.
     * - RECTANGLE without a rect: wipe the canvas and all recorded points, start a new rect.
     * - ERASER / STROKE: begin a path at the cursor ('destination-out' vs 'source-over').
     * - POINT: wipe the canvas and brush/eraser strokes. Then paint and record the new point,
     *   repaint the previously recorded points, and snapshot the canvas.
     *
     * @param {MouseEvent} e - React mouse event from the canvas.
     */
    const handleMouseDown = (e) => {
        if (!isDrawingEnabled.current) return;
        
        const rect = canvasRef.current.getBoundingClientRect();
        const canvas = canvasRef.current;
        
        // Canvas pixels per displayed CSS pixel (the canvas may be drawn scaled).
        const scaleX = canvas.width / rect.width;
        const scaleY = canvas.height / rect.height;
        
        // Pointer position in canvas pixel coordinates
        const x = (e.clientX - rect.left) * scaleX;
        const y = (e.clientY - rect.top) * scaleY;
        
        const ctx = ctxRef.current;

        // Hit-test the circular delete button centred horizontally above rect.y.
        // NOTE(review): this geometry is repeated in handleMouseMove/handleCanvasClick and must
        // match drawDeleteButton (not visible here).
        if (currentMode === DrawingMode.RECTANGLE && currentRect) {
            const buttonSize = getButtonSize(canvas);
            const buttonX = currentRect.x + currentRect.width / 2 - buttonSize / 2;
            const buttonY = currentRect.y - buttonSize - 5;
            
            const dx = x - (buttonX + buttonSize / 2);
            const dy = y - (buttonY + buttonSize / 2);
            
            if (Math.sqrt(dx * dx + dy * dy) <= buttonSize / 2) {
                // Delete button hit: remove the rectangle and keep the event from going further.
                deleteCurrentRectangle();
                e.preventDefault();
                e.stopPropagation();
                return;
            }
        }

        // While a rectangle exists, only the rectangle tool may interact with the canvas.
        if (currentRect && currentMode !== DrawingMode.RECTANGLE) {
            return;
        }

        // Start a resize when a handle was grabbed, or a drag when inside the rectangle.
        if (currentMode === DrawingMode.RECTANGLE && currentRect) {
            const handle = getResizeHandle(x, y, currentRect);
            if (handle) {
                setIsResizing(true);
                setResizeHandle(handle);
                return;
            }
            
            if (isPointInRectangle(x, y, currentRect)) {
                setIsDragging(true);
                // Remember where inside the rectangle it was grabbed so it moves without jumping.
                setDragOffset({
                    x: x - currentRect.x,
                    y: y - currentRect.y
                });
                return;
            }
        }

        if (currentMode === DrawingMode.RECTANGLE) {
            // A new rectangle replaces everything drawn or recorded so far.
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            setBrushPoints([]);
            setEraserPoints([]);
            setPositivePoints([]);
            setNegativePoints([]);
            setDrawingPoints([]);
            
            // Anchor the rubber-band rectangle; the 10x10 size is a placeholder until the
            // first mouse move sets the real extent.
            isDrawing.current = true;
            setStartPoint({ x, y });
            setCurrentRect({ x, y, width: 10, height: 10 });
        } else if (currentMode === DrawingMode.ERASER) {
            // Begin an erasing path; 'destination-out' removes pixels under the stroke.
            isDrawing.current = true;
            ctx.globalCompositeOperation = 'destination-out';
            ctx.lineWidth = getScaledLineWidth();
            ctx.lineCap = 'round';
            ctx.beginPath();
            ctx.moveTo(x, y);
        } else if (currentMode === DrawingMode.STROKE) {
            // Begin a brush path and record its first sample.
            isDrawing.current = true;
            ctx.globalCompositeOperation = 'source-over';
            ctx.lineWidth = getScaledLineWidth();
            ctx.lineCap = 'round';
            ctx.strokeStyle = drawColor;
            ctx.beginPath();
            ctx.moveTo(x, y);
            setBrushPoints(prev => [...prev, { x, y }]);
        } else if (currentMode === DrawingMode.POINT) {
            // Points replace any brush/eraser strokes: wipe the canvas and those recordings.
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            setBrushPoints([]);
            setEraserPoints([]);
            
            // Paint and record the new point under the selected polarity (none if unselected).
            // NOTE(review): handleCanvasClick also records a point for the same click; if both
            // handlers are attached to the canvas, each click is stored twice. Check the JSX.
            drawPoint(ctx, x, y);
            if (pointSubMode === PointSubMode.POSITIVE) {
                setPositivePoints(prev => [...prev, { x, y }]);
            } else if (pointSubMode === PointSubMode.NEGATIVE) {
                setNegativePoints(prev => [...prev, { x, y }]);
            }
            
            // Repaint the earlier points wiped by clearRect. These arrays are this render's
            // values, so they do not include the point just recorded (already painted above).
            positivePoints.forEach(point => {
                drawPointWithType(ctx, point.x, point.y, true);
            });
            
            negativePoints.forEach(point => {
                drawPointWithType(ctx, point.x, point.y, false);
            });
            
            saveCanvasState();
        }
    };

    /**
     * Mouse-move on the canvas. Does nothing until a tool is enabled. In canvas pixel coordinates:
     * - RECTANGLE with a rect: show a pointer cursor over the delete button and stop there;
     *   otherwise show a resize/move/crosshair cursor.
     * - While resizing or dragging: repaint the canvas with only the updated rectangle, its
     *   handles and delete button, and store the result as currentRect.
     * - While drawing a new rectangle: repaint the rubber-band rectangle from startPoint.
     *   The signed delta is stored, so width/height go negative when dragging up or left.
     * - While erasing or stroking: extend the current path and record the sample.
     * Each rectangle repaint clears everything else on the canvas.
     *
     * @param {MouseEvent} e - React mouse event from the canvas.
     */
    const handleMouseMove = (e) => {
        if (!isDrawingEnabled.current) return;

        const rect = canvasRef.current.getBoundingClientRect();
        const canvas = canvasRef.current;
        
        // Canvas pixels per displayed CSS pixel
        const scaleX = canvas.width / rect.width;
        const scaleY = canvas.height / rect.height;
        
        // Pointer position in canvas pixel coordinates
        const x = (e.clientX - rect.left) * scaleX;
        const y = (e.clientY - rect.top) * scaleY;
        
        const ctx = ctxRef.current;

        // Cursor feedback for the delete button (same geometry as handleMouseDown's hit-test).
        if (currentMode === DrawingMode.RECTANGLE && currentRect) {
            const buttonSize = getButtonSize(canvas);
            const buttonX = currentRect.x + currentRect.width / 2 - buttonSize / 2;
            const buttonY = currentRect.y - buttonSize - 5;
            
            const dx = x - (buttonX + buttonSize / 2);
            const dy = y - (buttonY + buttonSize / 2);
            
            if (Math.sqrt(dx * dx + dy * dy) <= buttonSize / 2) {
                // Mouse is over the delete button, change cursor to pointer
                canvas.style.cursor = 'pointer';
                return;
            } else {
                // Reset cursor if not over the delete button
                canvas.style.cursor = isResizing ? 'nwse-resize' : isDragging ? 'move' : 'crosshair';
            }
        }

        // Resize in progress: updateRectangleOnResize (not visible here) computes the new rect.
        if (isResizing && currentRect && resizeHandle) {
            const newRect = updateRectangleOnResize(x, y);
            if (newRect) {
                // Clear canvas and redraw the rectangle
                ctx.clearRect(0, 0, canvas.width, canvas.height);
                
                // Add transparent white fill
                ctx.fillStyle = 'rgba(255, 255, 255, 0.2)';
                ctx.fillRect(newRect.x, newRect.y, newRect.width, newRect.height);
                
                // Draw the stroke
                ctx.strokeStyle = drawColor;
                ctx.lineWidth = Math.max(2, getScaledLineWidth() / 2); // Thinner line for rectangles
                ctx.strokeRect(newRect.x, newRect.y, newRect.width, newRect.height);
                drawResizeHandles(ctx, newRect);
                drawDeleteButton(ctx, newRect);
                setCurrentRect(newRect);
            }
            return;
        }

        // Drag in progress: move the origin, keeping the grab offset recorded on mouse-down.
        if (isDragging && currentRect) {
            const newX = x - dragOffset.x;
            const newY = y - dragOffset.y;
            const newRect = {
                ...currentRect,
                x: newX,
                y: newY
            };
            
            // Clear canvas and redraw the rectangle
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            
            // Add transparent white fill
            ctx.fillStyle = 'rgba(255, 255, 255, 0.2)';
            ctx.fillRect(newRect.x, newRect.y, newRect.width, newRect.height);
            
            // Draw the stroke
            ctx.strokeStyle = drawColor;
            ctx.lineWidth = Math.max(2, getScaledLineWidth() / 2); // Thinner line for rectangles
            ctx.strokeRect(newRect.x, newRect.y, newRect.width, newRect.height);
            drawResizeHandles(ctx, newRect);
            drawDeleteButton(ctx, newRect);
            setCurrentRect(newRect);
            return;
        }

        if (currentMode === DrawingMode.RECTANGLE && isDrawing.current && startPoint) {
            // Signed extent from the anchor; negative when dragging up or to the left.
            const width = x - startPoint.x;
            const height = y - startPoint.y;
            
            // Clear canvas and draw the new rectangle
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            
            // Add transparent white fill
            ctx.fillStyle = 'rgba(255, 255, 255, 0.2)';
            ctx.fillRect(startPoint.x, startPoint.y, width, height);
            
            // Draw the stroke
            ctx.strokeStyle = drawColor;
            ctx.lineWidth = Math.max(2, getScaledLineWidth() / 2); // Thinner line for rectangles
            ctx.strokeRect(startPoint.x, startPoint.y, width, height);
            
            // Update current rectangle
            setCurrentRect({
                x: startPoint.x,
                y: startPoint.y,
                width: width,
                height: height
            });
        } else if (currentMode === DrawingMode.ERASER && isDrawing.current) {
            // Extend the erasing path and record the sample
            ctx.lineTo(x, y);
            ctx.stroke();
            setEraserPoints(prev => [...prev, { x, y }]);
        } else if (currentMode === DrawingMode.STROKE && isDrawing.current) {
            // Extend the brush path and record the sample
            ctx.lineTo(x, y);
            ctx.stroke();
            setBrushPoints(prev => [...prev, { x, y }]);
        }
    };

    /**
     * Mouse-up: finish whatever gesture is in progress.
     * - Ending a drag or a resize clears that gesture's state and snapshots the canvas.
     * - Ending a draw (isDrawing set):
     *   - for a new rectangle, repaint it with fill, handles and delete button, and append
     *     points sampled every 5 canvas pixels along its edges to brushPoints;
     *   - for STROKE/ERASER, append a null end-of-stroke marker.
     *   Then reset gesture state and snapshot the canvas.
     * NOTE(review): brushPoints gets rectangle edge points only when the rectangle is created.
     * A later drag or resize moves currentRect but not those points; confirm how they are used.
     */
    const handleMouseUp = () => {
        // Always reset dragging and resizing states when mouse is released
        if (isDragging) {
            setIsDragging(false);
            setDragOffset({ x: 0, y: 0 });
            saveCanvasState();
        }
        
        if (isResizing) {
            setIsResizing(false);
            setResizeHandle(null);
            saveCanvasState();
        }
        
        if (isDrawing.current) {
            isDrawing.current = false;
            ctxRef.current.closePath();
            
            // isResizing/isDragging here are this render's values, not the ones just set above.
            if (currentMode === DrawingMode.RECTANGLE && currentRect && !isResizing && !isDragging) {
                // Draw the rectangle with fill
                const ctx = ctxRef.current;
                
                // Redraw the rectangle with fill
                ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
                
                // Add transparent white fill
                ctx.fillStyle = 'rgba(255, 255, 255, 0.2)';
                ctx.fillRect(currentRect.x, currentRect.y, currentRect.width, currentRect.height);
                
                // Draw the stroke
                ctx.strokeStyle = drawColor;
                ctx.lineWidth = Math.max(2, getScaledLineWidth() / 2);
                ctx.strokeRect(currentRect.x, currentRect.y, currentRect.width, currentRect.height);
                
                // Draw resize handles and delete button for the rectangle
                drawResizeHandles(ctx, currentRect);
                drawDeleteButton(ctx, currentRect);
                
                // Add the rectangle points to brushPoints only when creating a new rectangle
                const { x, y, width, height } = currentRect;
                const points = [];
                
                // Sample the four edges every 5 canvas pixels; the sign of width/height picks
                // the direction, so rectangles drawn up/left are sampled correctly too.
                for (let i = 0; i <= Math.abs(width); i += 5) {
                    points.push({ x: x + (width > 0 ? i : -i), y, isRectanglePoint: true });
                    points.push({ x: x + (width > 0 ? i : -i), y: y + height, isRectanglePoint: true });
                }
                for (let i = 0; i <= Math.abs(height); i += 5) {
                    points.push({ x: x, y: y + (height > 0 ? i : -i), isRectanglePoint: true });
                    points.push({ x: x + width, y: y + (height > 0 ? i : -i), isRectanglePoint: true });
                }
                
                setBrushPoints(prev => [...prev, ...points]);
            } else if (currentMode === DrawingMode.STROKE) {
                // Add a null marker to indicate the end of a stroke
                setBrushPoints(prev => [...prev, null]);
            } else if (currentMode === DrawingMode.ERASER) {
                // Add a null marker to indicate the end of an eraser stroke
                setEraserPoints(prev => [...prev, null]);
            }
            
            // Reset states
            setStartPoint(null);
            setIsResizing(false);
            setResizeHandle(null);
            setIsDragging(false);
            setDragOffset({ x: 0, y: 0 });
            
            saveCanvasState();
        }
    };

    /**
     * Click on the canvas. Does nothing until a tool is enabled.
     * - RECTANGLE mode: a click on the delete button removes the rectangle.
     * - POINT mode: records {x, y, type: 'click'} in drawingPoints and paints the marker.
     *   It then appends the point to positivePoints/negativePoints per pointSubMode and
     *   snapshots the canvas.
     *
     * NOTE(review): handleMouseDown already paints and records the point for POINT-mode clicks.
     * If both handlers are attached to the canvas (JSX not visible here), every click is stored
     * twice in positivePoints/negativePoints and snapshotted twice. Confirm and keep one path.
     *
     * @param {MouseEvent} e - React click event from the canvas.
     */
    const handleCanvasClick = (e) => {
        if (!isDrawingEnabled.current) return;
        
        const rect = canvasRef.current.getBoundingClientRect();
        const canvas = canvasRef.current;
        
        // Canvas pixels per displayed CSS pixel
        const scaleX = canvas.width / rect.width;
        const scaleY = canvas.height / rect.height;
        
        // Pointer position in canvas pixel coordinates
        const x = (e.clientX - rect.left) * scaleX;
        const y = (e.clientY - rect.top) * scaleY;

        // Check if we're clicking the delete button for a rectangle
        if (currentMode === DrawingMode.RECTANGLE && currentRect) {
            const buttonSize = getButtonSize(canvas);
            const buttonX = currentRect.x + currentRect.width / 2 - buttonSize / 2;
            const buttonY = currentRect.y - buttonSize - 5;
            
            const dx = x - (buttonX + buttonSize / 2);
            const dy = y - (buttonY + buttonSize / 2);
            
            if (Math.sqrt(dx * dx + dy * dy) <= buttonSize / 2) {
                deleteCurrentRectangle();
                e.stopPropagation(); // Stop event propagation
                return;
            }
        }
        
        // Only draw points if we're in point mode
        if (currentMode === DrawingMode.POINT) {
            const ctx = ctxRef.current;
            
            // Canvas clearing for point mode happens in handleMouseDown, not here.
            
            setDrawingPoints(prev => [...prev, { x, y, type: 'click' }]);
            drawPoint(ctx, x, y);
            
            // Add the point to the appropriate array
            if (pointSubMode === PointSubMode.POSITIVE) {
                setPositivePoints(prev => [...prev, { x, y }]);
            } else if (pointSubMode === PointSubMode.NEGATIVE) {
                setNegativePoints(prev => [...prev, { x, y }]);
            }
            
            saveCanvasState();
        }
    };

    // Toolbar "reset" action: delegates to resetAll() and returns nothing.
    const resetCanvas = () => { resetAll(); };

    /**
     * Returns a copy of `points` without its most recent stroke. Each finished
     * stroke is terminated by a `null` marker, e.g. [a, b, null, c, d, null]
     * becomes [a, b, null]. A trailing, unterminated stroke is removed too.
     *
     * @param {Array<object|null>} points - stroke points with null separators.
     * @returns {Array<object|null>} a new array without the last stroke.
     */
    const dropLastStroke = (points) => {
        let end = points.length;
        // Step over the marker that closes the most recent stroke.
        if (end > 0 && points[end - 1] === null) {
            end -= 1;
        }
        if (end === 0) {
            return [];
        }
        // The previous marker (if any) closes the stroke before the last one.
        const previousMarker = points.lastIndexOf(null, end - 1);
        return previousMarker === -1 ? [] : points.slice(0, previousMarker + 1);
    };

    /**
     * Reverts the last drawing action. It restores the previous canvas snapshot
     * (or clears the canvas when there is none). It pushes the current snapshot
     * and the point lists onto the redo stacks. It then removes the matching
     * points: the rectangle, the last prompt point, or the last stroke.
     */
    const undoDrawing = () => {
        if (canvasStates.length === 0) return;

        const canvas = canvasRef.current;
        const ctx = ctxRef.current;
        // Without a canvas the pixels can't be restored. Bail out so the point
        // lists and redo stacks don't drift out of sync with canvasStates.
        if (!canvas || !ctx) return;

        const currentState = canvasStates[canvasStates.length - 1];
        const previousState = canvasStates.length > 1 ? canvasStates[canvasStates.length - 2] : null;

        if (previousState) {
            const img = new Image();
            img.onload = () => {
                ctx.clearRect(0, 0, canvas.width, canvas.height);
                ctx.drawImage(img, 0, 0);
            };
            img.src = previousState;
        } else {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
        }

        setRedoStates(prev => [...prev, currentState]);
        setCanvasStates(prev => prev.slice(0, -1));

        // Snapshot every point list so redoDrawing can restore it.
        const currentPoints = {
            brush: [...brushPoints],
            eraser: [...eraserPoints],
            positive: [...positivePoints],
            negative: [...negativePoints],
            rect: currentRect
        };
        setRedoPoints(prev => [...prev, currentPoints]);

        // brushPoints also holds null stroke terminators, so guard the property access.
        const isRectanglePoint = (point) => Boolean(point && point.isRectanglePoint);

        if (currentRect && brushPoints.some(isRectanglePoint)) {
            setCurrentRect(null);
            setBrushPoints(prev => prev.filter(point => !isRectanglePoint(point)));
            return;
        }

        if (positivePoints.length > 0 || negativePoints.length > 0) {
            // NOTE(review): the two lists carry no ordering information, so
            // "added last" is approximated. Use the positive list when there are
            // no negatives or when it is longer; ties go to the negative list.
            if (negativePoints.length === 0 || positivePoints.length > negativePoints.length) {
                setPositivePoints(prev => prev.slice(0, -1));
            } else {
                setNegativePoints(prev => prev.slice(0, -1));
            }
            return;
        }

        // NOTE(review): brush and eraser strokes live in separate lists with no
        // shared ordering. Undo the eraser list only when the eraser is the
        // active tool (or no brush strokes remain); otherwise undo a brush stroke.
        const undoEraserStroke = eraserPoints.length > 0
            && (currentMode === DrawingMode.ERASER || brushPoints.length === 0);

        if (undoEraserStroke) {
            setEraserPoints(prev => dropLastStroke(prev));
        } else if (brushPoints.length > 0) {
            setBrushPoints(prev => dropLastStroke(prev));
        }
    };

    /**
     * Re-applies the most recently undone action. It redraws that action's
     * canvas snapshot and restores the point lists saved by undo (and the
     * rectangle, when one was saved). It then moves the snapshot from the redo
     * stack back onto the undo stack. Does nothing without a canvas/context.
     */
    const redoDrawing = () => {
        if (redoStates.length === 0) return;

        const canvas = canvasRef.current;
        const ctx = ctxRef.current;
        if (!canvas || !ctx) return;

        const snapshot = redoStates[redoStates.length - 1];
        const savedPoints = redoPoints[redoPoints.length - 1];

        const img = new Image();
        img.src = snapshot;
        img.onload = () => {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            ctx.drawImage(img, 0, 0);
        };

        if (savedPoints) {
            setBrushPoints(savedPoints.brush || []);
            setEraserPoints(savedPoints.eraser || []);
            setPositivePoints(savedPoints.positive || []);
            setNegativePoints(savedPoints.negative || []);

            // Only a saved rectangle is restored; an absent one leaves currentRect alone.
            if (savedPoints.rect) {
                setCurrentRect(savedPoints.rect);
            }
        }

        setCanvasStates(prev => [...prev, snapshot]);
        setRedoStates(prev => prev.slice(0, -1));
        setRedoPoints(prev => prev.slice(0, -1));
    };

    // Turns off all canvas drawing: clears the enabled flag, switches the mode
    // to NONE and restores normal ('source-over') compositing on the context.
    const disableDrawing = () => {
        isDrawingEnabled.current = false;
        setCurrentMode(DrawingMode.NONE);

        const ctx = ctxRef.current;
        if (!ctx) return;
        ctx.globalCompositeOperation = 'source-over';
    };

    /**
     * Toolbar handler for point mode. It is refused while a rectangle exists.
     * Otherwise it leaves point mode if active, closes the sub-mode menu if
     * open, or opens that menu after switching every other drawing mode off.
     */
    const handlePointClick = () => {
        if (currentRect) {
            setErrorMessage('Please remove the rectangle before using point mode');
            setShowErrorMessage(true);
            return;
        }

        if (currentMode === DrawingMode.POINT) {
            disableDrawing();
            setPointSubMode(null);
            return;
        }

        if (showPointSubmodeMenu) {
            setShowPointSubmodeMenu(false);
            return;
        }

        // Nothing is cleared here. Per the original note, strokes are cleared
        // once a specific point sub-mode is selected (handled elsewhere).
        setIsDrawingMode(false);
        setIsEraserMode(false);
        setIsRectangleMode(false);
        setCurrentMode(DrawingMode.NONE);
        setShowPointSubmodeMenu(true);
    };

    /**
     * Toolbar handler for brush mode. It is refused while a rectangle exists.
     * It toggles brush mode off when already active. Otherwise it leaves point
     * mode, sets the line width from getScaledLineWidth() and enables the brush.
     * Existing strokes and points are kept.
     */
    const handleDrawingClick = () => {
        if (currentRect) {
            setErrorMessage('Please remove the rectangle before using brush mode');
            setShowErrorMessage(true);
            return;
        }

        if (currentMode === DrawingMode.STROKE) {
            disableDrawing();
            return;
        }

        setIsPointMode(false);
        setPointSubMode(null);
        setShowPointSubmodeMenu(false);

        setLineWidth(getScaledLineWidth());
        // drawColor / fillStyleColor are deliberately not re-set here. Writing
        // the closure's current values back is a no-op at best, and at worst it
        // overwrites an update queued earlier with a stale value.
        enableBrushDrawing();
    };

    /**
     * Toolbar handler for eraser mode. It is refused while a rectangle exists.
     * It toggles eraser mode off when already active. Otherwise it leaves point
     * mode, sets the line width from getScaledLineWidth() and enables the
     * eraser. Existing strokes and points are kept.
     */
    const handleEraserClick = () => {
        if (currentRect) {
            setErrorMessage('Please remove the rectangle before using eraser mode');
            setShowErrorMessage(true);
            return;
        }

        if (currentMode === DrawingMode.ERASER) {
            disableDrawing();
            return;
        }

        setIsPointMode(false);
        setPointSubMode(null);
        setShowPointSubmodeMenu(false);

        setLineWidth(getScaledLineWidth());
        // eraserColor / fillStyleColor are deliberately not re-set here. Writing
        // the closure's current values back is a no-op at best, and at worst it
        // overwrites an update queued earlier with a stale value.
        enableEraser();
    };

    /**
     * Toolbar handler for rectangle mode. When rectangle mode is active it
     * turns it off and deletes any existing rectangle. Otherwise it refuses if
     * a rectangle already exists, or leaves point mode and enables rectangles.
     */
    const handleRectangleClick = () => {
        const inRectangleMode = currentMode === DrawingMode.RECTANGLE;

        if (inRectangleMode) {
            disableDrawing();
            if (currentRect) {
                deleteCurrentRectangle();
            }
            return;
        }

        // Only one rectangle is allowed at a time.
        if (currentRect) {
            setErrorMessage('Please remove the existing rectangle before drawing a new one');
            setShowErrorMessage(true);
            return;
        }

        setIsPointMode(false);
        setPointSubMode(null);
        setShowPointSubmodeMenu(false);
        enableRectangleDrawing();
    };

    /**
     * Scales an image Blob/File to exactly `dimensions` and re-encodes it as PNG.
     *
     * The temporary object URL is always revoked (it previously leaked on every
     * call). Errors thrown while drawing reject the promise instead of leaving
     * it pending forever.
     *
     * @param {Blob} imageFile - the source image.
     * @param {{width: number, height: number}} dimensions - output size in pixels.
     * @returns {Promise<Blob>} the resized PNG. Rejects with "Failed to load image"
     *   when the image can't be decoded, or "Failed to resize image" when the
     *   canvas produces no Blob.
     */
    const resizeImage = async (imageFile, dimensions) => {
        const objectUrl = URL.createObjectURL(imageFile);

        return new Promise((resolve, reject) => {
            const img = new Image();

            img.onload = () => {
                try {
                    const tempCanvas = document.createElement('canvas');
                    tempCanvas.width = dimensions.width;
                    tempCanvas.height = dimensions.height;
                    const tempCtx = tempCanvas.getContext('2d');

                    // Stretch the decoded image onto the target-sized canvas.
                    tempCtx.drawImage(img, 0, 0, dimensions.width, dimensions.height);

                    tempCanvas.toBlob((blob) => {
                        if (blob) {
                            resolve(blob);
                        } else {
                            reject(new Error('Failed to resize image'));
                        }
                    }, 'image/png');
                } catch (err) {
                    reject(err);
                } finally {
                    // The pixels are on the canvas now; the URL is no longer needed.
                    URL.revokeObjectURL(objectUrl);
                }
            };

            img.onerror = () => {
                URL.revokeObjectURL(objectUrl);
                reject(new Error('Failed to load image'));
            };

            // Assign the source only after the handlers are attached.
            img.src = objectUrl;
        });
    };

    /**
     * Rasterises a 2D mask onto a new canvas with one pixel per cell.
     *
     * @param {Array<Array<*>>} mask - mask[row][col]; a truthy cell is inside the mask.
     * @param {number} index - not used by this function; kept for existing callers.
     * @param {boolean} binary - true: opaque white inside / opaque black outside.
     *   false: translucent rose inside, fully transparent outside, plus a solid
     *   rose 2px outline along edges between inside and outside cells.
     * @returns {HTMLCanvasElement} the newly created canvas.
     */
    const draw2DArrayOnCanvas = (mask, index = 0, binary = false) => {
        const canvas = document.createElement('canvas');
        canvas.width = mask[0].length;
        canvas.height = mask.length;
        const ctx = canvas.getContext('2d');

        const rows = mask.length;
        const cols = mask[0].length;
        const imageData = ctx.createImageData(canvas.width, canvas.height);
        const pixels = imageData.data;

        // RGBA written for cells inside / outside the mask.
        const insideRGBA = binary ? [255, 255, 255, 255] : [244, 63, 94, 128];
        const outsideRGBA = binary ? [0, 0, 0, 255] : [0, 0, 0, 0];

        for (let cell = 0; cell < rows * cols; cell++) {
            const row = Math.floor(cell / cols);
            const col = cell % cols;
            pixels.set(mask[row][col] ? insideRGBA : outsideRGBA, cell * 4);
        }
        ctx.putImageData(imageData, 0, 0);

        if (binary) {
            return canvas;
        }

        // Bounding box of the inside cells; it limits the edge scan below.
        let minCol = canvas.width;
        let maxCol = 0;
        let minRow = canvas.height;
        let maxRow = 0;
        for (let row = 0; row < rows; row++) {
            for (let col = 0; col < cols; col++) {
                if (!mask[row][col]) continue;
                minCol = Math.min(minCol, col);
                maxCol = Math.max(maxCol, col);
                minRow = Math.min(minRow, row);
                maxRow = Math.max(maxRow, row);
            }
        }

        ctx.strokeStyle = 'rgba(244, 63, 94, 1)';
        ctx.lineWidth = 2;
        ctx.beginPath();

        const segment = (x0, y0, x1, y1) => {
            ctx.moveTo(x0, y0);
            ctx.lineTo(x1, y1);
        };

        // Outline each side of an inside cell that touches an outside cell.
        // Sides on the canvas border are not outlined.
        for (let row = minRow; row <= maxRow; row++) {
            for (let col = minCol; col <= maxCol; col++) {
                if (!mask[row][col]) continue;

                const openLeft = col > 0 && !mask[row][col - 1];
                const openRight = col < cols - 1 && !mask[row][col + 1];
                const openTop = row > 0 && !mask[row - 1][col];
                const openBottom = row < rows - 1 && !mask[row + 1][col];

                if (openLeft) segment(col, row, col, row + 1);
                if (openRight) segment(col + 1, row, col + 1, row + 1);
                if (openTop) segment(col, row, col + 1, row);
                if (openBottom) segment(col, row + 1, col + 1, row + 1);
            }
        }
        ctx.stroke();

        return canvas;
    };

    /**
     * Renders a 2D mask as an opaque black/white image and returns it as a PNG
     * data URL string. Despite the name, no File object is created.
     *
     * @param {Array<Array<*>>} mask - mask[row][col]; truthy cells become white.
     * @returns {string} `data:image/png;...` URL of the mask.
     */
    const binaryArrayToFile = (mask) => {
        const rows = mask.length;
        const cols = mask[0].length;

        const output = document.createElement('canvas');
        output.width = cols;
        output.height = rows;
        const outputCtx = output.getContext('2d');

        outputCtx.drawImage(draw2DArrayOnCanvas(mask, 0, true), 0, 0);

        return output.toDataURL('image/png');
    };

    /**
     * Requests a segmentation mask from the backend for the current image,
     * using the positive/negative prompt points and/or the text prompt.
     *
     * On success it:
     * - OR-merges the returned masks and draws the result black/white onto the
     *   editor canvas;
     * - passes the result to the parent as a PNG data URL via setBinaryMaskImage;
     * - resets the stroke and prompt inputs.
     *
     * Every step after validation runs inside try/finally: image fetch, resize,
     * upload and response decoding. A failure anywhere therefore shows an error
     * and clears isGeneratingMask, instead of leaving the editor stuck in the
     * "generating" state.
     */
    const handleGenerateMaskClick = async () => {
        console.log('Generate Mask');
        console.log(drawingPoints);

        setShowErrorMessage(false);
        setErrorMessage('');

        if (!authenticated) {
            setErrorMessage('Please sign in to generate a mask.');
            setShowErrorMessage(true);
            return;
        }

        // Ignore clicks while a request is already in flight.
        if (isGeneratingMask) {
            return;
        }

        const { input_points, input_labels } = processDrawingPoints();
        // A whitespace-only prompt is treated the same as an empty one.
        const trimmedPrompt = prompt.trim();

        if (input_points.length === 0 && trimmedPrompt === '') {
            setErrorMessage('Draw points or enter a prompt to generate a mask.');
            setShowErrorMessage(true);
            return;
        }

        const showGenerationError = () => {
            setErrorMessage('Failed to generate mask. Please try again.');
            setShowErrorMessage(true);
        };

        setIsGeneratingMask(true);

        try {
            // Resize to the stored image dimensions rather than the canvas size.
            const image_dimensions = {
                width: imageDimensions.width,
                height: imageDimensions.height
            };
            console.log("currentImage", currentImage);
            console.log("image_dimensions", image_dimensions);

            // Fetch the current image and re-encode it as a PNG at image_dimensions.
            const sourceResponse = await fetch(currentImage);
            const sourceBlob = await sourceResponse.blob();
            const sourceFile = new File([sourceBlob], 'image.png', { type: 'image/png' });
            const resizedBlob = await resizeImage(sourceFile, image_dimensions);
            const resizedFile = new File([resizedBlob], 'image.png', { type: 'image/png' });

            const formData = new FormData();
            formData.append('image', resizedFile);
            formData.append('experiment_id', 233326);
            if (input_points.length > 0) {
                formData.append('input_point', JSON.stringify(input_points));
                formData.append('input_label', JSON.stringify(input_labels));
            }
            if (trimmedPrompt !== '') {
                // The prompt is always sent terminated with a period.
                formData.append('prompt', trimmedPrompt.endsWith('.') ? trimmedPrompt : `${trimmedPrompt}.`);
            }

            const response = await fetch(`${process.env.REACT_APP_TRYON_SERVER_URL}/api/v1/masks/`, {
                method: 'POST',
                body: formData,
                headers: {
                    'Authorization': `Bearer ${accessToken}`
                }
            });

            if (!response.ok) {
                console.log('Failed to generate mask');
                showGenerationError();
                return;
            }

            const data = await response.json();
            console.log('Mask generated successfully');
            // data.data.masks holds one or more 2D masks; they are OR-merged into one.
            const masks = data.data.masks;
            const mergedMask = mergeMasks(masks);
            if (!mergedMask) {
                // The response contained no masks; there is nothing to draw.
                console.log('Failed to generate mask');
                showGenerationError();
                return;
            }

            if (canvasRef.current) {
                const ctx = canvasRef.current.getContext('2d');
                ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
                ctx.drawImage(draw2DArrayOnCanvas(mergedMask, 0, true), 0, 0);
                saveCanvasState();
            }

            // Hand the merged mask to the parent as a PNG data URL.
            setBinaryMaskImage(binaryArrayToFile(mergedMask));
            setIsMaskGenerated(true);

            // Clear the inputs that produced this mask.
            setDrawingPoints([]);
            setBrushPoints([]);
            setEraserPoints([]);
            setPrompt('');
            setIsDrawingMode(false);
            setIsEraserMode(false);
        } catch (error) {
            console.error('Error:', error);
            showGenerationError();
        } finally {
            setIsGeneratingMask(false);
        }
    };

    /**
     * Finalises the mask and closes the editor. It shows an error if nothing
     * was generated or drawn. The mask source is chosen in this order:
     * 1. An already generated mask is kept as-is, when no brush strokes were
     *    added and the canvas has any non-black RGB pixel.
     * 2. A rectangle is rendered white on black.
     * 3. Brush strokes are used: the current canvas if a mask was generated,
     *    otherwise strokes redrawn white on black.
     */
    const handleContinueClick = () => {
        console.log('Continue');

        setErrorMessage('');
        setShowErrorMessage(false);

        const hasBrushPoints = brushPoints.length > 0;
        if (!isMaskGenerated && !currentRect && !hasBrushPoints) {
            setErrorMessage('No mask generated or drawn. Please generate or draw a mask first.');
            setShowErrorMessage(true);
            return;
        }

        const canvas = canvasRef.current;
        const context = canvas.getContext('2d');
        const pixels = context.getImageData(0, 0, canvas.width, canvas.height).data;

        // Becomes true at the first pixel with a non-zero R, G or B (alpha is ignored).
        let canvasHasColor = false;
        for (let offset = 0; offset < pixels.length && !canvasHasColor; offset += 4) {
            canvasHasColor = pixels[offset] !== 0 || pixels[offset + 1] !== 0 || pixels[offset + 2] !== 0;
        }

        const closeEditor = () => {
            cleanupOnClose(false);
            setShowMaskEditor(false);
        };

        if (isMaskGenerated && !hasBrushPoints && canvasHasColor) {
            closeEditor();
            return;
        }

        if (currentRect) {
            // Binary mask: black everywhere, white inside the rectangle.
            context.clearRect(0, 0, canvas.width, canvas.height);
            context.fillStyle = 'black';
            context.fillRect(0, 0, canvas.width, canvas.height);
            context.fillStyle = 'white';
            context.fillRect(currentRect.x, currentRect.y, currentRect.width, currentRect.height);

            setBinaryMaskImage(context.canvas.toDataURL("image/png"));
            closeEditor();
            return;
        }

        if (hasBrushPoints) {
            const maskSource = isMaskGenerated ? canvas : drawBrushPoints(brushPoints, canvas);
            setBinaryMaskImage(maskSource.toDataURL("image/png"));
        }

        closeEditor();
    };

    /**
     * Redraws brush strokes as a binary mask onto `canvas`, replacing its
     * contents: black background, white round-capped strokes whose width
     * comes from getScaledLineWidth().
     *
     * Stroke boundaries are now read from the `points` argument. Previously
     * they came from the `brushPoints` state, so the path breaks could
     * disagree with the list actually being drawn.
     *
     * @param {Array<{x:number, y:number}|null>} points - stroke points; a
     *   `null` entry ends a stroke, so the next point starts a new path.
     * @param {HTMLCanvasElement} canvas - target canvas (cleared first).
     * @returns {HTMLCanvasElement} the same canvas.
     */
    const drawBrushPoints = (points, canvas) => {
        const ctx = canvas.getContext('2d');
        ctx.clearRect(0, 0, canvas.width, canvas.height);
        ctx.fillStyle = 'black';
        ctx.fillRect(0, 0, canvas.width, canvas.height);

        ctx.strokeStyle = 'white';
        ctx.fillStyle = 'white';
        ctx.lineWidth = getScaledLineWidth();
        ctx.lineCap = 'round';

        points.forEach((point, index) => {
            if (point === null) {
                ctx.stroke();
                return;
            }

            // The first point of each stroke only positions the pen.
            if (index === 0 || points[index - 1] === null) {
                ctx.beginPath();
                ctx.moveTo(point.x, point.y);
            } else {
                // Stroke this segment, then restart the path at the current point.
                ctx.lineTo(point.x, point.y);
                ctx.stroke();
                ctx.beginPath();
                ctx.moveTo(point.x, point.y);
            }
        });
        ctx.stroke();
        return canvas;
    };

    /**
     * Builds the point-prompt payload from the clicked prompt points.
     *
     * @returns {{input_points: number[][], input_labels: number[]}} parallel
     *   arrays of [x, y] pairs and labels. Entries from positivePoints come
     *   first; each is labelled 1, unless it has type 'point', in which case
     *   its isPositive flag picks 1 or 0. Entries from negativePoints follow,
     *   always labelled 0.
     */
    const processDrawingPoints = () => {
        const positiveLabel = (point) => {
            if (point.type !== 'point') return 1;
            return point.isPositive ? 1 : 0;
        };

        const entries = [
            ...positivePoints.map((point) => ({ coords: [point.x, point.y], label: positiveLabel(point) })),
            ...negativePoints.map((point) => ({ coords: [point.x, point.y], label: 0 })),
        ];

        return {
            input_points: entries.map((entry) => entry.coords),
            input_labels: entries.map((entry) => entry.label),
        };
    };

    /**
     * Combines several equally sized 2D masks into one by bitwise OR.
     *
     * @param {Array<Array<Array<*>>>} masks - masks[i][row][col].
     * @returns {Array<Array<*>>|null} a new 2D array sized like masks[0], or
     *   null when `masks` is missing or empty. With a single mask its cell
     *   values are copied unchanged; otherwise cells are the numeric OR result.
     */
    const mergeMasks = (masks) => {
        if (!masks || masks.length === 0) return null;

        const [base, ...others] = masks;
        const rows = base.length;
        const cols = base[0].length;

        return Array.from({ length: rows }, (_, row) =>
            Array.from({ length: cols }, (_, col) =>
                others.reduce((acc, other) => acc | other[row][col], base[row][col])
            )
        );
    };

    // Arms the canvas for rectangle drawing. It switches to RECTANGLE mode and
    // sets normal ('source-over') compositing. The outline width is half of
    // getScaledLineWidth(), but never below 2.
    const enableRectangleDrawing = () => {
        isDrawingEnabled.current = true;
        setCurrentMode(DrawingMode.RECTANGLE);

        const ctx = ctxRef.current;
        if (!ctx) return;
        ctx.globalCompositeOperation = 'source-over';
        ctx.lineWidth = Math.max(2, getScaledLineWidth() / 2);
    };

    // Paints a white, stroked circle on each corner of `rect` as a resize grip.
    // The radius is half of getHandleSize(). Returns the four corner
    // descriptors (x, y and the CSS cursor for that corner) in nw/ne/sw/se order.
    const drawResizeHandles = (ctx, rect) => {
        const radius = getHandleSize(canvasRef.current) / 2;
        const right = rect.x + rect.width;
        const bottom = rect.y + rect.height;

        const handles = [
            { x: rect.x, y: rect.y, cursor: 'nw-resize' },
            { x: right, y: rect.y, cursor: 'ne-resize' },
            { x: rect.x, y: bottom, cursor: 'sw-resize' },
            { x: right, y: bottom, cursor: 'se-resize' }
        ];

        ctx.fillStyle = 'white';
        for (const { x, y } of handles) {
            ctx.beginPath();
            ctx.arc(x, y, radius, 0, Math.PI * 2);
            ctx.fill();
            ctx.stroke();
        }

        return handles;
    };

    // Hit-tests (x, y) against the four corners of `rect`. Corners are checked
    // in nw, ne, sw, se order. Returns the first corner whose distance is at
    // most getHandleSize() canvas pixels, as { x, y, cursor, type }, or null
    // when nothing is hit or there is no rect.
    const getResizeHandle = (x, y, rect) => {
        if (!rect) return null;

        const hitRadius = getHandleSize(canvasRef.current);
        const right = rect.x + rect.width;
        const bottom = rect.y + rect.height;

        const corners = [
            { x: rect.x, y: rect.y, cursor: 'nw-resize', type: 'nw' },
            { x: right, y: rect.y, cursor: 'ne-resize', type: 'ne' },
            { x: rect.x, y: bottom, cursor: 'sw-resize', type: 'sw' },
            { x: right, y: bottom, cursor: 'se-resize', type: 'se' }
        ];

        const hit = corners.find((corner) => {
            const dx = x - corner.x;
            const dy = y - corner.y;
            return Math.sqrt(dx * dx + dy * dy) <= hitRadius;
        });

        return hit || null;
    };

    // Returns a copy of currentRect with the edges next to the active resize
    // handle moved to (x, y); the opposite edges stay fixed. West handles
    // (nw/sw) move the left edge, east handles (ne/se) the right edge. North
    // handles (nw/ne) move the top edge, south handles (sw/se) the bottom edge.
    // No clamping is done, so width/height can go negative. Returns undefined
    // when there is no rectangle or no active handle.
    const updateRectangleOnResize = (x, y) => {
        if (!currentRect || !resizeHandle) return;

        const left = currentRect.x;
        const top = currentRect.y;
        const right = left + currentRect.width;
        const bottom = top + currentRect.height;
        const handle = resizeHandle.type;
        const resized = { ...currentRect };

        if (handle === 'nw' || handle === 'sw') {
            resized.width = right - x;
            resized.x = x;
        } else if (handle === 'ne' || handle === 'se') {
            resized.width = x - left;
        }

        if (handle === 'nw' || handle === 'ne') {
            resized.height = bottom - y;
            resized.y = y;
        } else if (handle === 'sw' || handle === 'se') {
            resized.height = y - top;
        }

        return resized;
    };

    /**
     * Removes the current rectangle. It clears the whole canvas, resets all
     * rectangle interaction state (including the drag offset) and drops the
     * rectangle outline points from brushPoints. It then saves the cleared
     * canvas as a new undo state. Does nothing when the canvas or its 2D
     * context is unavailable.
     */
    const deleteCurrentRectangle = () => {
        const canvas = canvasRef.current;
        const ctx = ctxRef.current;
        if (!canvas || !ctx) return;

        ctx.clearRect(0, 0, canvas.width, canvas.height);

        setCurrentRect(null);
        setStartPoint(null);
        setIsResizing(false);
        setResizeHandle(null);
        setIsDragging(false);
        // Also reset the drag offset, matching the reset done when a drawing
        // interaction ends.
        setDragOffset({ x: 0, y: 0 });

        // brushPoints may contain null stroke terminators, which have no
        // properties. Keep them and drop only rectangle outline points.
        setBrushPoints(prev => prev.filter(point => !(point && point.isRectanglePoint)));

        saveCanvasState();
    };

    // Delete-button size in canvas pixels for a 16 CSS-px on-screen button.
    // It scales by the ratio of the canvas' pixel width to its displayed width.
    // Returns 16 unchanged when no canvas is given.
    const getButtonSize = (canvas) => {
        const VIEWPORT_BUTTON_SIZE = 16;
        if (!canvas) return VIEWPORT_BUTTON_SIZE;

        const displayedWidth = canvas.getBoundingClientRect().width;
        return VIEWPORT_BUTTON_SIZE * (canvas.width / displayedWidth);
    };

    // Resize-handle size in canvas pixels for an 8 CSS-px on-screen handle.
    // It scales by the ratio of the canvas' pixel width to its displayed width.
    // Returns 8 unchanged when no canvas is given.
    const getHandleSize = (canvas) => {
        const VIEWPORT_HANDLE_SIZE = 8;
        if (!canvas) return VIEWPORT_HANDLE_SIZE;

        const displayedWidth = canvas.getBoundingClientRect().width;
        return VIEWPORT_HANDLE_SIZE * (canvas.width / displayedWidth);
    };

    // Draws the rectangle's delete control: a translucent white circle with a
    // black "X", sized by getButtonSize(). It is centred horizontally on `rect`
    // and sits 5 units above rect.y. It leaves ctx.fillStyle, ctx.strokeStyle
    // ('#000') and ctx.lineWidth changed.
    const drawDeleteButton = (ctx, rect) => {
        const size = getButtonSize(canvasRef.current);
        const left = rect.x + rect.width / 2 - size / 2;
        const top = rect.y - size - 5;

        ctx.fillStyle = 'rgba(255, 255, 255, 0.9)';
        ctx.strokeStyle = 'rgba(255, 255, 255, 1)';
        ctx.beginPath();
        ctx.arc(left + size / 2, top + size / 2, size / 2, 0, Math.PI * 2);
        ctx.fill();
        ctx.stroke();

        // The "X": two diagonals inset by a quarter of the button size.
        ctx.beginPath();
        ctx.strokeStyle = '#000';
        ctx.lineWidth = Math.max(1, size / 8);
        const inset = size / 4;
        const near = { x: left + inset, y: top + inset };
        const far = { x: left + size - inset, y: top + size - inset };
        ctx.moveTo(near.x, near.y);
        ctx.lineTo(far.x, far.y);
        ctx.moveTo(far.x, near.y);
        ctx.lineTo(near.x, far.y);
        ctx.stroke();
    };

    /**
     * Inverts the mask currently on the canvas, in place.
     * Every pixel's red, green and blue channels are flipped (c -> 255 - c) and
     * alpha is left unchanged, so white areas become black and vice versa.
     * Afterwards it:
     *  - publishes the result through setBinaryMaskImage as a PNG data URL,
     *  - toggles the isMaskInverted flag,
     *  - saves a new undo state.
     *
     * If no mask has been generated yet, it shows the error toast instead.
     * It returns silently when the canvas is unavailable or has no pixels.
     */
    const invertMask = () => {
        if (!isMaskGenerated) {
            setShowErrorMessage(true);
            setErrorMessage("There is no mask generated. Please generate a mask using the generate mask button first");
            return;
        }

        // canvasRef.current can be null (e.g. before the canvas mounts).
        // getImageData throws IndexSizeError for a zero-sized region.
        const canvas = canvasRef.current;
        if (!canvas || canvas.width === 0 || canvas.height === 0) return;
        const ctx = canvas.getContext('2d');
        if (!ctx) return;

        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const { data } = imageData;

        // Pixels are stored as consecutive RGBA bytes.
        for (let i = 0; i < data.length; i += 4) {
            data[i] = 255 - data[i];         // Red
            data[i + 1] = 255 - data[i + 1]; // Green
            data[i + 2] = 255 - data[i + 2]; // Blue
            // data[i + 3] (alpha) is intentionally left unchanged
        }

        ctx.putImageData(imageData, 0, 0);

        setBinaryMaskImage(canvas.toDataURL("image/png"));
        // Functional update so the toggle never reads a stale closure value.
        setIsMaskInverted((inverted) => !inverted);

        saveCanvasState();
    };

    return (
        <div 
            className="fixed inset-0 z-50 flex items-center justify-center bg-black bg-opacity-75 touch-none overflow-y-auto" 
            onClick={() => {
                cleanupOnClose(true);
                setShowMaskEditor(false);
            }}
        >
            <div 
                className="relative bg-white w-[90vw] lg:w-[50vw] mx-4 lg:w-7xl h-5xl max-h-[90vh] rounded-lg overflow-hidden p-4 shadow-lg overflow-y-auto" 
                onClick={(e) => e.stopPropagation()}
            >
                <div className="flex flex-col items-center justify-center mx-auto w-auto lg:w-full">
                    <div className="flex flex-row items-center justify-between w-full">
                        <div className="w-full text-start flex items-center font-semibold text-xl font-Mulish">Generate/Draw Mask</div>
                        <button className="w-6 h-auto bg-white shadow-lg rounded-full p-1 cursor-pointer" 
                            onClick={() => {
                                cleanupOnClose(true);
                                setShowMaskEditor(false);
                            }}>
                            <img src={CloseIcon} alt="close icon" />
                        </button>
                    </div>

                    <p className="w-full my-2 text-start font-semibold mb-4 text-sm lg:text-base text-gray-500 flex-wrap">Draw points on the image or enter a prompt to generate a mask. Draw a rectangle to select an area of the image to mask for generation.</p>

                    <div className="relative overflow-hidden">
                        {isDrawing && (
                            <canvas
                                ref={canvasRef}
                                className="h-full w-full object-cover object-top rounded-lg absolute top-0 left-0"
                                style={{ 
                                    zIndex: 1,
                                    cursor: currentMode === DrawingMode.RECTANGLE 
                                        ? 'crosshair'
                                        : currentMode === DrawingMode.STROKE || currentMode === DrawingMode.ERASER
                                            ? 'crosshair'
                                            : 'default'
                                }}
                                onMouseDown={handleMouseDown}
                                onMouseMove={handleMouseMove}
                                onMouseUp={handleMouseUp}
                                onMouseLeave={handleMouseUp}
                                onClick={handleCanvasClick}
                                onTouchStart={(e) => {
                                    e.preventDefault(); // Prevent scrolling while drawing
                                    const touch = e.touches[0];
                                    const mouseEvent = new MouseEvent('mousedown', {
                                        clientX: touch.clientX,
                                        clientY: touch.clientY
                                    });
                                    handleMouseDown(mouseEvent);
                                }}
                                onTouchMove={(e) => {
                                    e.preventDefault();
                                    const touch = e.touches[0];
                                    const mouseEvent = new MouseEvent('mousemove', {
                                        clientX: touch.clientX,
                                        clientY: touch.clientY
                                    });
                                    handleMouseMove(mouseEvent);
                                }}
                                onTouchEnd={(e) => {
                                    e.preventDefault();
                                    handleMouseUp();
                                }}
                            />
                        )}
                        <img 
                            src={currentImage} 
                            className="h-auto max-h-[50vh] w-auto max-w-[80vw] mx-auto object-cover object-top rounded-lg" 
                            alt="image"
                        />
                    </div>
                 
                    <div className="flex flex-row flex-wrap gap-2 lg:gap-0 items-center justify-center space-x-4 rounded-lg border-2 border-gray-200/50 mt-2 p-3 bg-gray-50">
                        <Tooltip content="Draw Points" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <div className="relative">
                                <button 
                                className={`flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform hover:scale-105 focus:outline-none
                                    ${currentMode === DrawingMode.POINT ? 'ring-2 ring-rose-500 bg-rose-50' : 'hover:bg-gray-50'}`} 
                                onClick={handlePointClick}
                            >
                                {pointSubMode === PointSubMode.POSITIVE ? (
                                    <PiCircleFill className="text-green-500" />
                                ) : pointSubMode === PointSubMode.NEGATIVE ? (
                                    <PiCircleFill className="text-red-500" />
                                ) : (
                                    <PiCircle className="text-gray-500" />
                                )}
                            </button>
                            
                            {showPointSubmodeMenu && (
                                <div ref={dropdownRef} className="absolute left-1/2 -translate-x-1/2 top-full mt-2 bg-white rounded-lg shadow-2xl z-50 border border-gray-100 overflow-hidden transform origin-top transition-all duration-200 ease-out animate-fadeIn flex flex-col">
                                    <div className="flex justify-between items-center border border-gray-200/50 rounded-lg">
                                        <button
                                            className="flex items-center justify-between px-4 py-3 hover:bg-gray-100 transition-colors duration-200 border-r border-gray-200/50"
                                            onClick={() => {
                                                setPointSubMode(PointSubMode.POSITIVE);
                                                setShowPointSubmodeMenu(false);
                                                enablePointDrawing();
                                            }}
                                        >
                                            <div className="flex items-center space-x-2">
                                                <div className="w-5 h-5 relative">
                                                    <div className="w-3 h-3 bg-green-500 rounded-full absolute top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2"></div>
                                                    <div className="w-5 h-5 bg-green-200 rounded-full"></div>
                                                </div>
                                            </div>
                                        </button>
                                        <button
                                            className="flex items-center justify-between px-4 py-3 hover:bg-gray-100 transition-colors duration-200"
                                            onClick={() => {
                                                setPointSubMode(PointSubMode.NEGATIVE);
                                                setShowPointSubmodeMenu(false);
                                                enablePointDrawing();
                                            }}
                                        >
                                            <div className="flex items-center space-x-2">
                                                <div className="w-5 h-5 relative">
                                                    <div className="w-3 h-3 bg-red-500 rounded-full absolute top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2"></div>
                                                    <div className="w-5 h-5 bg-red-200 rounded-full"></div>
                                                </div>
                                            </div>
                                        </button>
                                    </div>
                                </div>
                            )}
                        </div>
                        </Tooltip>

                        <Tooltip content="Draw Strokes" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <button 
                                className={`flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform hover:scale-105 focus:outline-none
                                    ${currentMode === DrawingMode.STROKE ? 'ring-2 ring-rose-500 bg-rose-50' : 'hover:bg-gray-50'}`} 
                                onClick={handleDrawingClick}
                            >
                                <PiPaintBrushLight />
                            </button>
                        </Tooltip>

                        <Tooltip content="Eraser" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <button 
                                className={`flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform hover:scale-105 focus:outline-none
                                    ${currentMode === DrawingMode.ERASER ? 'ring-2 ring-rose-500 bg-rose-50' : 'hover:bg-gray-50'}`}
                                onClick={handleEraserClick}
                            >
                                <PiEraserLight />
                            </button>
                        </Tooltip>

                        <Tooltip content="Rectangle" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <button 
                                className={`flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform hover:scale-105 focus:outline-none
                                    ${currentMode === DrawingMode.RECTANGLE ? 'ring-2 ring-rose-500 bg-rose-50' : 'hover:bg-gray-50'}`}
                                onClick={handleRectangleClick}
                            >
                                <PiSquare />
                            </button>
                        </Tooltip>

                        <Tooltip content="Undo" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <button 
                                className="flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform hover:scale-105 focus:outline-none hover:bg-gray-50" 
                                onClick={undoDrawing}
                            >
                                <SlActionUndo />
                            </button>
                        </Tooltip>

                        <Tooltip content="Redo" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <button 
                                className="flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform hover:scale-105 focus:outline-none hover:bg-gray-50" 
                                onClick={redoDrawing}
                            >
                                <SlActionRedo />
                            </button>
                        </Tooltip>

                        <Tooltip content="Reset" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <button 
                                className="flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform hover:scale-105 focus:outline-none hover:bg-gray-50"
                                onClick={resetCanvas}
                            >
                                <RxReset />
                            </button>
                        </Tooltip>

                        <Tooltip content="Invert Mask" className="text-xs bg-gray-800 text-white p-1 rounded">
                            <button 
                                className={`flex items-center justify-center bg-white shadow-md rounded-full p-2 transition-all duration-200 transform 
                                    ${(isDrawingMode || isPointMode || isEraserMode || isRectangleMode) ? 'opacity-50 cursor-not-allowed' : 'hover:scale-105 hover:bg-gray-50 cursor-pointer'}`}
                                onClick={invertMask}
                                disabled={isDrawingMode || isPointMode || isEraserMode || isRectangleMode}
                            >
                                <PiArrowsInSimpleLight />
                            </button>
                        </Tooltip>
                    </div>

                    <p className="w-full text-center py-2 font-semibold text-xs text-gray-500 lg:text-sm font-Mulish underline">Or</p>

                    <input type="text" className="w-full px-3 py-2 border border-gray-200 rounded-lg focus:outline-none focus:border-rose-300 focus:ring-1 focus:ring-rose-300 text-gray-700" value={prompt} onChange={(e) => setPrompt(e.target.value)} placeholder="Enter prompt to generate mask e.g. 't-shirt, jeans, shoes, etc.'" />

                    <div className="flex flex-col lg:flex-row items-center justify-between w-full gap-4 mt-4">
                        <button className="px-4 py-2 bg-gradient-to-r from-rose-400 to-rose-500 hover:from-rose-500 hover:to-rose-600 text-white rounded-lg text-sm font-medium transition-all duration-200 shadow-sm hover:shadow" onClick={handleGenerateMaskClick}>{isGeneratingMask ? <div className="flex items-center justify-center"> <span className="mr-2">Generating Mask...</span> <Spinner width="4" height="4" fill="fill-white" /></div> : 'Generate Mask'}</button>
                        <button className={`px-4 py-2 bg-gradient-to-r ${isMaskGenerated || currentRect || brushPoints.length > 0 ? 'from-rose-500 to-rose-600 hover:from-rose-500 hover:to-rose-600' : 'from-gray-400 to-gray-500 hover:from-gray-500 hover:to-gray-500'} text-white rounded-lg text-sm font-medium transition-all duration-200 shadow-sm hover:shadow`} onClick={handleContinueClick}>Continue With Mask</button>
                    </div>

                    <p className="w-full my-2 text-start font-semibold mt-4 text-xs lg:text-sm font-Mulish text-gray-500 flex-wrap flex-wrap">Note: Create a mask for the area you want to use for generation. After the mask is generated, the white area will be used for generation. The black area will not be used.</p>
                </div>
            </div>

            {showErrorMessage && (
                <ToastErrorMessage errorMessage={errorMessage} errorStatus={"Error"} handleErrorClose={(e) => {
                    e.stopPropagation();
                    setShowErrorMessage(false);
                }} />
            )}
        </div>
    );
}