diff --git a/custom-demo/front-end/src/App.tsx b/custom-demo/front-end/src/App.tsx index 6650840..7b57f6a 100644 --- a/custom-demo/front-end/src/App.tsx +++ b/custom-demo/front-end/src/App.tsx @@ -1,167 +1,169 @@ -import { useCallback, useEffect, useRef } from "react" +import { useCallback, useEffect, useRef } from "react"; -import useInputImage from "@/hooks/useInputImage" -import { keepGUIAlive } from "@/lib/utils" -import { getServerConfig } from "@/lib/api" -import Header from "@/components/Header" -import Workspace from "@/components/Workspace" -import FileSelect from "@/components/FileSelect" -import { Toaster } from "./components/ui/toaster" -import { useStore } from "./lib/states" -import { useWindowSize } from "react-use" +import useInputImage from "@/hooks/useInputImage"; +import { keepGUIAlive } from "@/lib/utils"; +import { getServerConfig } from "@/lib/api"; +import Header from "@/components/Header"; +import Workspace from "@/components/Workspace"; +import FileSelect from "@/components/FileSelect"; +import { Toaster } from "./components/ui/toaster"; +import { useStore } from "./lib/states"; +import { useWindowSize } from "react-use"; const SUPPORTED_FILE_TYPE = [ - "image/jpeg", - "image/png", - "image/webp", - "image/bmp", - "image/tiff", -] + "image/jpeg", + "image/png", + "image/webp", + "image/bmp", + "image/tiff", +]; function Home() { - const [file, updateAppState, setServerConfig, setFile] = useStore((state) => [ - state.file, - state.updateAppState, - state.setServerConfig, - state.setFile, - ]) + const [file, updateAppState, setServerConfig, setFile] = useStore( + (state) => [ + state.file, + state.updateAppState, + state.setServerConfig, + state.setFile, + ], + ); - const userInputImage = useInputImage() + const userInputImage = useInputImage(); - const windowSize = useWindowSize() + const windowSize = useWindowSize(); - useEffect(() => { - if (userInputImage) { - setFile(userInputImage) - } - }, [userInputImage, setFile]) + useEffect(() => { + if (userInputImage) { + setFile(userInputImage); + } + }, [userInputImage, setFile]); - useEffect(() => { - updateAppState({ windowSize }) - }, [windowSize]) + useEffect(() => { + updateAppState({ windowSize }); + }, [windowSize]); - useEffect(() => { - const fetchServerConfig = async () => { - const serverConfig = await getServerConfig() - setServerConfig(serverConfig) - if (serverConfig.isDesktop) { - // Keeping GUI Window Open - keepGUIAlive() - } - } - fetchServerConfig() - }, []) + useEffect(() => { + const fetchServerConfig = async () => { + const serverConfig = await getServerConfig(); + setServerConfig(serverConfig); + if (serverConfig.isDesktop) { + // Keeping GUI Window Open + keepGUIAlive(); + } + }; + fetchServerConfig(); + }, []); - const dragCounter = useRef(0) + const dragCounter = useRef(0); - const handleDrag = useCallback((event: any) => { - event.preventDefault() - event.stopPropagation() - }, []) + const handleDrag = useCallback((event: any) => { + event.preventDefault(); + event.stopPropagation(); + }, []); - const handleDragIn = useCallback((event: any) => { - event.preventDefault() - event.stopPropagation() - dragCounter.current += 1 - }, []) + const handleDragIn = useCallback((event: any) => { + event.preventDefault(); + event.stopPropagation(); + dragCounter.current += 1; + }, []); - const handleDragOut = useCallback((event: any) => { - event.preventDefault() - event.stopPropagation() - dragCounter.current -= 1 - if (dragCounter.current > 0) return - }, []) + const handleDragOut = useCallback((event: any) => { + event.preventDefault(); + event.stopPropagation(); + dragCounter.current -= 1; + if (dragCounter.current > 0) return; + }, []); - const handleDrop = useCallback((event: any) => { - event.preventDefault() - event.stopPropagation() - if (event.dataTransfer.files && event.dataTransfer.files.length > 0) { - if (event.dataTransfer.files.length > 1) { - // setToastState({ - // open: true, - // desc: "Please drag and drop only one file", - // state: "error", - // duration: 3000, - // }) - } else { - const dragFile = event.dataTransfer.files[0] - const fileType = dragFile.type - if (SUPPORTED_FILE_TYPE.includes(fileType)) { - setFile(dragFile) - } else { - // setToastState({ - // open: true, - // desc: "Please drag and drop an image file", - // state: "error", - // duration: 3000, - // }) - } - } - event.dataTransfer.clearData() - } - }, []) + const handleDrop = useCallback((event: any) => { + event.preventDefault(); + event.stopPropagation(); + if (event.dataTransfer.files && event.dataTransfer.files.length > 0) { + if (event.dataTransfer.files.length > 1) { + // setToastState({ + // open: true, + // desc: "Please drag and drop only one file", + // state: "error", + // duration: 3000, + // }) + } else { + const dragFile = event.dataTransfer.files[0]; + const fileType = dragFile.type; + if (SUPPORTED_FILE_TYPE.includes(fileType)) { + setFile(dragFile); + } else { + // setToastState({ + // open: true, + // desc: "Please drag and drop an image file", + // state: "error", + // duration: 3000, + // }) + } + } + event.dataTransfer.clearData(); + } + }, []); - const onPaste = useCallback((event: any) => { - // TODO: when sd side panel open, ctrl+v not work - // https://htmldom.dev/paste-an-image-from-the-clipboard/ - if (!event.clipboardData) { - return - } - const clipboardItems = event.clipboardData.items - const items: DataTransferItem[] = [].slice - .call(clipboardItems) - .filter((item: DataTransferItem) => { - // Filter the image items only - return item.type.indexOf("image") !== -1 - }) + const onPaste = useCallback((event: any) => { + // TODO: when sd side panel open, ctrl+v not work + // https://htmldom.dev/paste-an-image-from-the-clipboard/ + if (!event.clipboardData) { + return; + } + const clipboardItems = event.clipboardData.items; + const items: DataTransferItem[] = [].slice + .call(clipboardItems) + .filter((item: DataTransferItem) => { + // Filter the image items only + return item.type.indexOf("image") !== -1; + }); - if (items.length === 0) { - return - } + if (items.length === 0) { + return; + } - event.preventDefault() - event.stopPropagation() + event.preventDefault(); + event.stopPropagation(); - // TODO: add confirm dialog + // TODO: add confirm dialog - const item = items[0] - // Get the blob of image - const blob = item.getAsFile() - if (blob) { - setFile(blob) - } - }, []) + const item = items[0]; + // Get the blob of image + const blob = item.getAsFile(); + if (blob) { + setFile(blob); + } + }, []); - useEffect(() => { - window.addEventListener("dragenter", handleDragIn) - window.addEventListener("dragleave", handleDragOut) - window.addEventListener("dragover", handleDrag) - window.addEventListener("drop", handleDrop) - window.addEventListener("paste", onPaste) - return function cleanUp() { - window.removeEventListener("dragenter", handleDragIn) - window.removeEventListener("dragleave", handleDragOut) - window.removeEventListener("dragover", handleDrag) - window.removeEventListener("drop", handleDrop) - window.removeEventListener("paste", onPaste) - } - }) + useEffect(() => { + window.addEventListener("dragenter", handleDragIn); + window.addEventListener("dragleave", handleDragOut); + window.addEventListener("dragover", handleDrag); + window.addEventListener("drop", handleDrop); + window.addEventListener("paste", onPaste); + return function cleanUp() { + window.removeEventListener("dragenter", handleDragIn); + window.removeEventListener("dragleave", handleDragOut); + window.removeEventListener("dragover", handleDrag); + window.removeEventListener("drop", handleDrop); + window.removeEventListener("paste", onPaste); + }; + }); - return ( -
- -
- - {!file ? ( - { - setFile(f) - }} - /> - ) : ( - <> - )} -
- ) + return ( +
+ +
+ + {!file ? ( + { + setFile(f); + }} + /> + ) : ( + <> + )} +
+ ); } -export default Home +export default Home; diff --git a/custom-demo/front-end/src/components/Editor.tsx b/custom-demo/front-end/src/components/Editor.tsx index 5bd4ef7..f476edc 100644 --- a/custom-demo/front-end/src/components/Editor.tsx +++ b/custom-demo/front-end/src/components/Editor.tsx @@ -1,989 +1,1018 @@ -import { SyntheticEvent, useCallback, useEffect, useRef, useState } from "react" -import { CursorArrowRaysIcon } from "@heroicons/react/24/outline" -import { useToast } from "@/components/ui/use-toast" import { - ReactZoomPanPinchContentRef, - TransformComponent, - TransformWrapper, -} from "react-zoom-pan-pinch" -import { useKeyPressEvent } from "react-use" -import { downloadToOutput, runPlugin } from "@/lib/api" -import { IconButton } from "@/components/ui/button" + SyntheticEvent, + useCallback, + useEffect, + useRef, + useState, +} from "react"; +import { CursorArrowRaysIcon } from "@heroicons/react/24/outline"; +import { useToast } from "@/components/ui/use-toast"; import { - askWritePermission, - cn, - copyCanvasImage, - downloadImage, - drawLines, - generateMask, - isMidClick, - isRightClick, - mouseXY, - srcToFile, -} from "@/lib/utils" -import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react" -import { useImage } from "@/hooks/useImage" -import { Slider } from "./ui/slider" -import { PluginName } from "@/lib/types" -import { useStore } from "@/lib/states" -import Cropper from "./Cropper" -import { InteractiveSegPoints } from "./InteractiveSeg" -import useHotKey from "@/hooks/useHotkey" -import Extender from "./Extender" -import { MAX_BRUSH_SIZE, MIN_BRUSH_SIZE } from "@/lib/const" + ReactZoomPanPinchContentRef, + TransformComponent, + TransformWrapper, +} from "react-zoom-pan-pinch"; +import { useKeyPressEvent } from "react-use"; +import { downloadToOutput, runPlugin } from "@/lib/api"; +import { IconButton } from "@/components/ui/button"; +import { + askWritePermission, + cn, + copyCanvasImage, + downloadImage, + drawLines, + generateMask, + isMidClick, + isRightClick, + mouseXY, + srcToFile, +} from "@/lib/utils"; +import { Eraser, Eye, Redo, Undo, Expand, Download, Send } from "lucide-react"; +import { useImage } from "@/hooks/useImage"; +import { Slider } from "./ui/slider"; +import { PluginName } from "@/lib/types"; +import { useStore } from "@/lib/states"; +import { InteractiveSegPoints } from "./InteractiveSeg"; +import useHotKey from "@/hooks/useHotkey"; +import Extender from "./Extender"; +import { MAX_BRUSH_SIZE, MIN_BRUSH_SIZE } from "@/lib/const"; -const TOOLBAR_HEIGHT = 200 -const COMPARE_SLIDER_DURATION_MS = 300 +const TOOLBAR_HEIGHT = 200; +const COMPARE_SLIDER_DURATION_MS = 300; interface EditorProps { - file: File + file: File; } export default function Editor(props: EditorProps) { - const { file } = props - const { toast } = useToast() - - const [ - disableShortCuts, - windowSize, - isInpainting, - imageWidth, - imageHeight, - settings, - enableAutoSaving, - setImageSize, - setBaseBrushSize, - interactiveSegState, - updateInteractiveSegState, - handleCanvasMouseDown, - handleCanvasMouseMove, - undo, - redo, - undoDisabled, - redoDisabled, - isProcessing, - updateAppState, - runMannually, - runInpainting, - isCropperExtenderResizing, - decreaseBaseBrushSize, - increaseBaseBrushSize, - ] = useStore((state) => [ - state.disableShortCuts, - state.windowSize, - state.isInpainting, - state.imageWidth, - state.imageHeight, - state.settings, - state.serverConfig.enableAutoSaving, - state.setImageSize, - state.setBaseBrushSize, - state.interactiveSegState, - state.updateInteractiveSegState, - state.handleCanvasMouseDown, - state.handleCanvasMouseMove, - state.undo, - state.redo, - state.undoDisabled(), - state.redoDisabled(), - state.getIsProcessing(), - state.updateAppState, - state.runMannually(), - state.runInpainting, - state.isCropperExtenderResizing, - state.decreaseBaseBrushSize, - state.increaseBaseBrushSize, - ]) - const baseBrushSize = useStore((state) => state.editorState.baseBrushSize) - const brushSize = useStore((state) => state.getBrushSize()) - const renders = useStore((state) => state.editorState.renders) - const extraMasks = useStore((state) => state.editorState.extraMasks) - const temporaryMasks = useStore((state) => state.editorState.temporaryMasks) - const lineGroups = useStore((state) => state.editorState.lineGroups) - const curLineGroup = useStore((state) => state.editorState.curLineGroup) - - // Local State - const [showOriginal, setShowOriginal] = useState(false) - const [original, isOriginalLoaded] = useImage(file) - const [context, setContext] = useState() - const [imageContext, setImageContext] = useState() - const [{ x, y }, setCoords] = useState({ x: -1, y: -1 }) - const [showBrush, setShowBrush] = useState(false) - const [showRefBrush, setShowRefBrush] = useState(false) - const [isPanning, setIsPanning] = useState(false) - - const [scale, setScale] = useState(1) - const [panned, setPanned] = useState(false) - const [minScale, setMinScale] = useState(1.0) - const windowCenterX = windowSize.width / 2 - const windowCenterY = windowSize.height / 2 - const viewportRef = useRef(null) - // Indicates that the image has been loaded and is centered on first load - const [initialCentered, setInitialCentered] = useState(false) - - const [isDraging, setIsDraging] = useState(false) - - const [sliderPos, setSliderPos] = useState(0) - const [isChangingBrushSizeByWheel, setIsChangingBrushSizeByWheel] = - useState(false) - - const hadDrawSomething = useCallback(() => { - return curLineGroup.length !== 0 - }, [curLineGroup]) - - useEffect(() => { - if ( - !imageContext || - !isOriginalLoaded || - imageWidth === 0 || - imageHeight === 0 - ) { - return - } - const render = renders.length === 0 ? original : renders[renders.length - 1] - imageContext.canvas.width = imageWidth - imageContext.canvas.height = imageHeight - - imageContext.clearRect( - 0, - 0, - imageContext.canvas.width, - imageContext.canvas.height - ) - imageContext.drawImage(render, 0, 0, imageWidth, imageHeight) - }, [ - renders, - original, - isOriginalLoaded, - imageContext, - imageHeight, - imageWidth, - ]) - - useEffect(() => { - if ( - !context || - !isOriginalLoaded || - imageWidth === 0 || - imageHeight === 0 - ) { - return - } - context.canvas.width = imageWidth - context.canvas.height = imageHeight - context.clearRect(0, 0, context.canvas.width, context.canvas.height) - temporaryMasks.forEach((maskImage) => { - context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) - }) - extraMasks.forEach((maskImage) => { - context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) - }) - - if ( - interactiveSegState.isInteractiveSeg && - interactiveSegState.tmpInteractiveSegMask - ) { - context.drawImage( - interactiveSegState.tmpInteractiveSegMask, - 0, - 0, - imageWidth, - imageHeight - ) - } - drawLines(context, curLineGroup) - }, [ - temporaryMasks, - extraMasks, - isOriginalLoaded, - interactiveSegState, - context, - curLineGroup, - imageHeight, - imageWidth, - ]) - - const getCurrentRender = useCallback(async () => { - let targetFile = file - if (renders.length > 0) { - const lastRender = renders[renders.length - 1] - targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type) - } - return targetFile - }, [file, renders]) - - const hadRunInpainting = () => { - return renders.length !== 0 - } - - const getCurrentWidthHeight = useCallback(() => { - let width = 512 - let height = 512 - if (!isOriginalLoaded) { - return [width, height] - } - if (renders.length === 0) { - width = original.naturalWidth - height = original.naturalHeight - } else if (renders.length !== 0) { - width = renders[renders.length - 1].width - height = renders[renders.length - 1].height - } - - return [width, height] - }, [original, isOriginalLoaded, renders]) - - // Draw once the original image is loaded - useEffect(() => { - if (!isOriginalLoaded) { - return - } - - const [width, height] = getCurrentWidthHeight() - if (width !== imageWidth || height !== imageHeight) { - setImageSize(width, height) - } - - const rW = windowSize.width / width - const rH = (windowSize.height - TOOLBAR_HEIGHT) / height - - let s = 1.0 - if (rW < 1 || rH < 1) { - s = Math.min(rW, rH) - } - setMinScale(s) - setScale(s) - - console.log( - `[on file load] image size: ${width}x${height}, scale: ${s}, initialCentered: ${initialCentered}` - ) - - if (context?.canvas) { - console.log("[on file load] set canvas size") - if (width != context.canvas.width) { - context.canvas.width = width - } - if (height != context.canvas.height) { - context.canvas.height = height - } - } - - if (!initialCentered) { - // 防止每次擦除以后图片 zoom 还原 - viewportRef.current?.centerView(s, 1) - console.log("[on file load] centerView") - setInitialCentered(true) - } - }, [ - viewportRef, - imageHeight, - imageWidth, - original, - isOriginalLoaded, - windowSize, - initialCentered, - getCurrentWidthHeight, - ]) - - useEffect(() => { - console.log("[useEffect] centerView") - // render 改变尺寸以后,undo/redo 重新 center - viewportRef?.current?.centerView(minScale, 1) - }, [imageHeight, imageWidth, viewportRef, minScale]) - - // Zoom reset - const resetZoom = useCallback(() => { - if (!minScale || !windowSize) { - return - } - const viewport = viewportRef.current - if (!viewport) { - return - } - const offsetX = (windowSize.width - imageWidth * minScale) / 2 - const offsetY = (windowSize.height - imageHeight * minScale) / 2 - viewport.setTransform(offsetX, offsetY, minScale, 200, "easeOutQuad") - if (viewport.instance.transformState.scale) { - viewport.instance.transformState.scale = minScale - } - - setScale(minScale) - setPanned(false) - }, [ - viewportRef, - windowSize, - imageHeight, - imageWidth, - windowSize.height, - minScale, - ]) - - useEffect(() => { - window.addEventListener("resize", () => { - resetZoom() - }) - return () => { - window.removeEventListener("resize", () => { - resetZoom() - }) - } - }, [windowSize, resetZoom]) - - const handleEscPressed = () => { - if (isProcessing) { - return - } - - if (isDraging) { - setIsDraging(false) - } else { - resetZoom() - } - } - - useHotKey("Escape", handleEscPressed, [ - isDraging, - isInpainting, - resetZoom, - // drawOnCurrentRender, - ]) - - const onMouseMove = (ev: SyntheticEvent) => { - const mouseEvent = ev.nativeEvent as MouseEvent - setCoords({ x: mouseEvent.pageX, y: mouseEvent.pageY }) - } - - const onMouseDrag = (ev: SyntheticEvent) => { - if (isProcessing) { - return - } - - if (interactiveSegState.isInteractiveSeg) { - return - } - if (isPanning) { - return - } - if (!isDraging) { - return - } - if (curLineGroup.length === 0) { - return - } - - handleCanvasMouseMove(mouseXY(ev)) - } - - const runInteractiveSeg = async (newClicks: number[][]) => { - updateAppState({ isPluginRunning: true }) - const targetFile = await getCurrentRender() - try { - const res = await runPlugin( - true, - PluginName.InteractiveSeg, - targetFile, - undefined, - newClicks - ) - const { blob } = res - const img = new Image() - img.onload = () => { - updateInteractiveSegState({ tmpInteractiveSegMask: img }) - } - img.src = blob - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - } - updateAppState({ isPluginRunning: false }) - } - - const onPointerUp = (ev: SyntheticEvent) => { - if (isMidClick(ev)) { - setIsPanning(false) - return - } - if (!hadDrawSomething()) { - return - } - if (interactiveSegState.isInteractiveSeg) { - return - } - if (isPanning) { - return - } - if (!original.src) { - return - } - const canvas = context?.canvas - if (!canvas) { - return - } - if (isInpainting) { - return - } - if (!isDraging) { - return - } - - if (runMannually) { - setIsDraging(false) - } else { - runInpainting() - } - } - - const onCanvasMouseUp = (ev: SyntheticEvent) => { - if (interactiveSegState.isInteractiveSeg) { - const xy = mouseXY(ev) - const newClicks: number[][] = [...interactiveSegState.clicks] - if (isRightClick(ev)) { - newClicks.push([xy.x, xy.y, 0, newClicks.length]) - } else { - newClicks.push([xy.x, xy.y, 1, newClicks.length]) - } - runInteractiveSeg(newClicks) - updateInteractiveSegState({ clicks: newClicks }) - } - } - - const onMouseDown = (ev: SyntheticEvent) => { - if (isProcessing) { - return - } - if (interactiveSegState.isInteractiveSeg) { - return - } - if (isPanning) { - return - } - if (!isOriginalLoaded) { - return - } - const canvas = context?.canvas - if (!canvas) { - return - } - - if (isRightClick(ev)) { - return - } - - if (isMidClick(ev)) { - setIsPanning(true) - return - } - - setIsDraging(true) - handleCanvasMouseDown(mouseXY(ev)) - } - - const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { - keyboardEvent.preventDefault() - undo() - } - useHotKey("meta+z,ctrl+z", handleUndo) - - const handleRedo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { - keyboardEvent.preventDefault() - redo() - } - useHotKey("shift+ctrl+z,shift+meta+z", handleRedo) - - useKeyPressEvent( - "Tab", - (ev) => { - ev?.preventDefault() - ev?.stopPropagation() - if (hadRunInpainting()) { - setShowOriginal(() => { - window.setTimeout(() => { - setSliderPos(100) - }, 10) - return true - }) - } - }, - (ev) => { - ev?.preventDefault() - ev?.stopPropagation() - if (hadRunInpainting()) { - window.setTimeout(() => { - setSliderPos(0) - }, 10) - window.setTimeout(() => { - setShowOriginal(false) - }, COMPARE_SLIDER_DURATION_MS) - } - } - ) - - const download = useCallback(async () => { - if (file === undefined) { - return - } - if (enableAutoSaving && renders.length > 0) { - try { - await downloadToOutput( - renders[renders.length - 1], - file.name, - file.type - ) - toast({ - description: "Save image success", - }) - } catch (e: any) { - toast({ - variant: "destructive", - title: "Uh oh! Something went wrong.", - description: e.message ? e.message : e.toString(), - }) - } - return - } - - // TODO: download to output directory - const name = file.name.replace(/(\.[\w\d_-]+)$/i, "_cleanup$1") - const curRender = renders[renders.length - 1] - downloadImage(curRender.currentSrc, name) - if (settings.enableDownloadMask) { - let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1") - maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg") - - const maskCanvas = generateMask(imageWidth, imageHeight, lineGroups) - // Create a link - const aDownloadLink = document.createElement("a") - // Add the name of the file to the link - aDownloadLink.download = maskFileName - // Attach the data to the link - aDownloadLink.href = maskCanvas.toDataURL("image/jpeg") - // Get the code to click the download link - aDownloadLink.click() - } - }, [ - file, - enableAutoSaving, - renders, - settings, - imageHeight, - imageWidth, - lineGroups, - ]) - - useHotKey("meta+s,ctrl+s", download) - - const toggleShowBrush = (newState: boolean) => { - if (newState !== showBrush && !isPanning && !isCropperExtenderResizing) { - setShowBrush(newState) - } - } - - const getCursor = useCallback(() => { - if (isProcessing) { - return "default" - } - if (isPanning) { - return "grab" - } - if (showBrush) { - return "none" - } - return undefined - }, [showBrush, isPanning, isProcessing]) - - useHotKey( - "[", - () => { - decreaseBaseBrushSize() - }, - [decreaseBaseBrushSize] - ) - - useHotKey( - "]", - () => { - increaseBaseBrushSize() - }, - [increaseBaseBrushSize] - ) - - // Manual Inpainting Hotkey - useHotKey( - "shift+r", - () => { - if (runMannually && hadDrawSomething()) { - runInpainting() - } - }, - [runMannually, runInpainting, hadDrawSomething] - ) - - useHotKey( - "ctrl+c,meta+c", - async () => { - const hasPermission = await askWritePermission() - if (hasPermission && renders.length > 0) { - if (context?.canvas) { - await copyCanvasImage(context?.canvas) - toast({ - title: "Copy inpainting result to clipboard", - }) - } - } - }, - [renders, context] - ) - - // Toggle clean/zoom tool on spacebar. - useKeyPressEvent( - " ", - (ev) => { - if (!disableShortCuts) { - ev?.preventDefault() - ev?.stopPropagation() - setShowBrush(false) - setIsPanning(true) - } - }, - (ev) => { - if (!disableShortCuts) { - ev?.preventDefault() - ev?.stopPropagation() - setShowBrush(true) - setIsPanning(false) - } - } - ) - - useKeyPressEvent( - "Alt", - (ev) => { - if (!disableShortCuts) { - ev?.preventDefault() - ev?.stopPropagation() - setIsChangingBrushSizeByWheel(true) - } - }, - (ev) => { - if (!disableShortCuts) { - ev?.preventDefault() - ev?.stopPropagation() - setIsChangingBrushSizeByWheel(false) - } - } - ) - - const getCurScale = (): number => { - let s = minScale - if (viewportRef.current?.instance?.transformState.scale !== undefined) { - s = viewportRef.current?.instance?.transformState.scale - } - return s! - } - - const getBrushStyle = (_x: number, _y: number) => { - const curScale = getCurScale() - return { - width: `${brushSize * curScale}px`, - height: `${brushSize * curScale}px`, - left: `${_x}px`, - top: `${_y}px`, - transform: "translate(-50%, -50%)", - } - } - - const renderBrush = (style: any) => { - return ( -
- ) - } - - const handleSliderChange = (value: number) => { - setBaseBrushSize(value) - - if (!showRefBrush) { - setShowRefBrush(true) - window.setTimeout(() => { - setShowRefBrush(false) - }, 10000) - } - } - - const renderInteractiveSegCursor = () => { - return ( -
- -
- ) - } - - const renderCanvas = () => { - return ( - { - if (r) { - viewportRef.current = r - } - }} - panning={{ disabled: !isPanning, velocityDisabled: true }} - wheel={{ step: 0.05, wheelDisabled: isChangingBrushSizeByWheel }} - centerZoomedOut - alignmentAnimation={{ disabled: true }} - centerOnInit - limitToBounds={false} - doubleClick={{ disabled: true }} - initialScale={minScale} - minScale={minScale * 0.3} - onPanning={() => { - if (!panned) { - setPanned(true) - } - }} - onZoom={(ref) => { - setScale(ref.state.scale) - }} - > - -
- { - if (r && !imageContext) { - const ctx = r.getContext("2d") - if (ctx) { - setImageContext(ctx) - } - } - }} - /> - { - e.preventDefault() - }} - onMouseOver={() => { - toggleShowBrush(true) - setShowRefBrush(false) - }} - onFocus={() => toggleShowBrush(true)} - onMouseLeave={() => toggleShowBrush(false)} - onMouseDown={onMouseDown} - onMouseUp={onCanvasMouseUp} - onMouseMove={onMouseDrag} - ref={(r) => { - if (r && !context) { - const ctx = r.getContext("2d") - if (ctx) { - setContext(ctx) - } - } - }} - /> -
- {showOriginal && ( - <> -
- original - - )} -
-
- - - - - - {interactiveSegState.isInteractiveSeg ? ( - - ) : ( - <> - )} - - - ) - } - - const handleScroll = (event: React.WheelEvent) => { - // deltaY 是垂直滚动增量,正值表示向下滚动,负值表示向上滚动 - // deltaX 是水平滚动增量,正值表示向右滚动,负值表示向左滚动 - if (!isChangingBrushSizeByWheel) { - return - } - - const { deltaY } = event - // console.log(`水平滚动增量: ${deltaX}, 垂直滚动增量: ${deltaY}`) - if (deltaY > 0) { - increaseBaseBrushSize() - } else if (deltaY < 0) { - decreaseBaseBrushSize() - } - } - - return ( - - ) + const { file } = props; + const { toast } = useToast(); + + const [ + disableShortCuts, + windowSize, + isInpainting, + imageWidth, + imageHeight, + settings, + enableAutoSaving, + setImageSize, + setBaseBrushSize, + interactiveSegState, + updateInteractiveSegState, + handleCanvasMouseDown, + handleCanvasMouseMove, + undo, + redo, + undoDisabled, + redoDisabled, + isProcessing, + updateAppState, + runMannually, + runInpainting, + submitMaskImage, + isCropperExtenderResizing, + decreaseBaseBrushSize, + increaseBaseBrushSize, + ] = useStore((state) => [ + state.disableShortCuts, + state.windowSize, + state.isInpainting, + state.imageWidth, + state.imageHeight, + state.settings, + state.serverConfig.enableAutoSaving, + state.setImageSize, + state.setBaseBrushSize, + state.interactiveSegState, + state.updateInteractiveSegState, + state.handleCanvasMouseDown, + state.handleCanvasMouseMove, + state.undo, + state.redo, + state.undoDisabled(), + state.redoDisabled(), + state.getIsProcessing(), + state.updateAppState, + state.runMannually(), + state.runInpainting, + state.submitMaskImage, + state.isCropperExtenderResizing, + state.decreaseBaseBrushSize, + state.increaseBaseBrushSize, + ]); + const baseBrushSize = useStore((state) => state.editorState.baseBrushSize); + const brushSize = useStore((state) => state.getBrushSize()); + const renders = useStore((state) => state.editorState.renders); + const extraMasks = useStore((state) => state.editorState.extraMasks); + const temporaryMasks = useStore( + (state) => state.editorState.temporaryMasks, + ); + const lineGroups = useStore((state) => state.editorState.lineGroups); + const curLineGroup = useStore((state) => state.editorState.curLineGroup); + + // Local State + const [showOriginal, setShowOriginal] = useState(false); + const [original, isOriginalLoaded] = useImage(file); + const [context, setContext] = useState(); + const [imageContext, setImageContext] = + useState(); + const [{ x, y }, setCoords] = useState({ x: -1, y: -1 }); + const [showBrush, setShowBrush] = useState(false); + const [showRefBrush, setShowRefBrush] = useState(false); + const [isPanning, setIsPanning] = useState(false); + + const [scale, setScale] = useState(1); + const [panned, setPanned] = useState(false); + const [minScale, setMinScale] = useState(1.0); + const windowCenterX = windowSize.width / 2; + const windowCenterY = windowSize.height / 2; + const viewportRef = useRef(null); + // Indicates that the image has been loaded and is centered on first load + const [initialCentered, setInitialCentered] = useState(false); + + const [isDraging, setIsDraging] = useState(false); + + const [sliderPos, setSliderPos] = useState(0); + const [isChangingBrushSizeByWheel, setIsChangingBrushSizeByWheel] = + useState(false); + + const hadDrawSomething = useCallback(() => { + return curLineGroup.length !== 0; + }, [curLineGroup]); + + useEffect(() => { + if ( + !imageContext || + !isOriginalLoaded || + imageWidth === 0 || + imageHeight === 0 + ) { + return; + } + const render = + renders.length === 0 ? original : renders[renders.length - 1]; + imageContext.canvas.width = imageWidth; + imageContext.canvas.height = imageHeight; + + imageContext.clearRect( + 0, + 0, + imageContext.canvas.width, + imageContext.canvas.height, + ); + imageContext.drawImage(render, 0, 0, imageWidth, imageHeight); + }, [ + renders, + original, + isOriginalLoaded, + imageContext, + imageHeight, + imageWidth, + ]); + + useEffect(() => { + if ( + !context || + !isOriginalLoaded || + imageWidth === 0 || + imageHeight === 0 + ) { + return; + } + context.canvas.width = imageWidth; + context.canvas.height = imageHeight; + context.clearRect(0, 0, context.canvas.width, context.canvas.height); + temporaryMasks.forEach((maskImage) => { + context.drawImage(maskImage, 0, 0, imageWidth, imageHeight); + }); + extraMasks.forEach((maskImage) => { + context.drawImage(maskImage, 0, 0, imageWidth, imageHeight); + }); + + if ( + interactiveSegState.isInteractiveSeg && + interactiveSegState.tmpInteractiveSegMask + ) { + context.drawImage( + interactiveSegState.tmpInteractiveSegMask, + 0, + 0, + imageWidth, + imageHeight, + ); + } + drawLines(context, curLineGroup); + }, [ + temporaryMasks, + extraMasks, + isOriginalLoaded, + interactiveSegState, + context, + curLineGroup, + imageHeight, + imageWidth, + ]); + + const getCurrentRender = useCallback(async () => { + let targetFile = file; + if (renders.length > 0) { + const lastRender = renders[renders.length - 1]; + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type, + ); + } + return targetFile; + }, [file, renders]); + + const hadRunInpainting = () => { + return renders.length !== 0; + }; + + const getCurrentWidthHeight = useCallback(() => { + let width = 512; + let height = 512; + if (!isOriginalLoaded) { + return [width, height]; + } + if (renders.length === 0) { + width = original.naturalWidth; + height = original.naturalHeight; + } else if (renders.length !== 0) { + width = renders[renders.length - 1].width; + height = renders[renders.length - 1].height; + } + + return [width, height]; + }, [original, isOriginalLoaded, renders]); + + // Draw once the original image is loaded + useEffect(() => { + if (!isOriginalLoaded) { + return; + } + + const [width, height] = getCurrentWidthHeight(); + if (width !== imageWidth || height !== imageHeight) { + setImageSize(width, height); + } + + const rW = windowSize.width / width; + const rH = (windowSize.height - TOOLBAR_HEIGHT) / height; + + let s = 1.0; + if (rW < 1 || rH < 1) { + s = Math.min(rW, rH); + } + setMinScale(s); + setScale(s); + + console.log( + `[on file load] image size: ${width}x${height}, scale: ${s}, initialCentered: ${initialCentered}`, + ); + + if (context?.canvas) { + console.log("[on file load] set canvas size"); + if (width != context.canvas.width) { + context.canvas.width = width; + } + if (height != context.canvas.height) { + context.canvas.height = height; + } + } + + if (!initialCentered) { + // 防止每次擦除以后图片 zoom 还原 + viewportRef.current?.centerView(s, 1); + console.log("[on file load] centerView"); + setInitialCentered(true); + } + }, [ + viewportRef, + imageHeight, + imageWidth, + original, + isOriginalLoaded, + windowSize, + initialCentered, + getCurrentWidthHeight, + ]); + + useEffect(() => { + console.log("[useEffect] centerView"); + // render 改变尺寸以后,undo/redo 重新 center + viewportRef?.current?.centerView(minScale, 1); + }, [imageHeight, imageWidth, viewportRef, minScale]); + + // Zoom reset + const resetZoom = useCallback(() => { + if (!minScale || !windowSize) { + return; + } + const viewport = viewportRef.current; + if (!viewport) { + return; + } + const offsetX = (windowSize.width - imageWidth * minScale) / 2; + const offsetY = (windowSize.height - imageHeight * minScale) / 2; + viewport.setTransform(offsetX, offsetY, minScale, 200, "easeOutQuad"); + if (viewport.instance.transformState.scale) { + viewport.instance.transformState.scale = minScale; + } + + setScale(minScale); + setPanned(false); + }, [ + viewportRef, + windowSize, + imageHeight, + imageWidth, + windowSize.height, + minScale, + ]); + + useEffect(() => { + window.addEventListener("resize", () => { + resetZoom(); + }); + return () => { + window.removeEventListener("resize", () => { + resetZoom(); + }); + }; + }, [windowSize, resetZoom]); + + const handleEscPressed = () => { + if (isProcessing) { + return; + } + + if (isDraging) { + setIsDraging(false); + } else { + resetZoom(); + } + }; + + useHotKey("Escape", handleEscPressed, [ + isDraging, + isInpainting, + resetZoom, + // drawOnCurrentRender, + ]); + + const onMouseMove = (ev: SyntheticEvent) => { + const mouseEvent = ev.nativeEvent as MouseEvent; + setCoords({ x: mouseEvent.pageX, y: mouseEvent.pageY }); + }; + + const onMouseDrag = (ev: SyntheticEvent) => { + if (isProcessing) { + return; + } + + if (interactiveSegState.isInteractiveSeg) { + return; + } + if (isPanning) { + return; + } + if (!isDraging) { + return; + } + if (curLineGroup.length === 0) { + return; + } + + handleCanvasMouseMove(mouseXY(ev)); + }; + + const runInteractiveSeg = async (newClicks: number[][]) => { + updateAppState({ isPluginRunning: true }); + const targetFile = await getCurrentRender(); + try { + const res = await runPlugin( + true, + PluginName.InteractiveSeg, + targetFile, + undefined, + newClicks, + ); + const { blob } = res; + const img = new Image(); + img.onload = () => { + updateInteractiveSegState({ tmpInteractiveSegMask: img }); + }; + img.src = blob; + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }); + } + updateAppState({ isPluginRunning: false }); + }; + + const onPointerUp = (ev: SyntheticEvent) => { + if (isMidClick(ev)) { + setIsPanning(false); + return; + } + if (!hadDrawSomething()) { + return; + } + if (interactiveSegState.isInteractiveSeg) { + return; + } + if (isPanning) { + return; + } + if (!original.src) { + return; + } + const canvas = context?.canvas; + if (!canvas) { + return; + } + if (isInpainting) { + return; + } + if (!isDraging) { + return; + } + + if (runMannually) { + setIsDraging(false); + } + // else { + // runInpainting() + // } + }; + + const onCanvasMouseUp = (ev: SyntheticEvent) => { + setIsDraging(false); + if (interactiveSegState.isInteractiveSeg) { + const xy = mouseXY(ev); + const newClicks: number[][] = [...interactiveSegState.clicks]; + if (isRightClick(ev)) { + newClicks.push([xy.x, xy.y, 0, newClicks.length]); + } else { + newClicks.push([xy.x, xy.y, 1, newClicks.length]); + } + runInteractiveSeg(newClicks); + updateInteractiveSegState({ clicks: newClicks }); + } + }; + + const onMouseDown = (ev: SyntheticEvent) => { + if (isProcessing) { + return; + } + if (interactiveSegState.isInteractiveSeg) { + return; + } + if (isPanning) { + return; + } + if (!isOriginalLoaded) { + return; + } + const canvas = context?.canvas; + if (!canvas) { + return; + } + + if (isRightClick(ev)) { + return; + } + + if (isMidClick(ev)) { + setIsPanning(true); + return; + } + + setIsDraging(true); + handleCanvasMouseDown(mouseXY(ev)); + }; + + const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { + keyboardEvent.preventDefault(); + undo(); + }; + useHotKey("meta+z,ctrl+z", handleUndo); + + const handleRedo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { + keyboardEvent.preventDefault(); + redo(); + }; + useHotKey("shift+ctrl+z,shift+meta+z", handleRedo); + + useKeyPressEvent( + "Tab", + (ev) => { + ev?.preventDefault(); + ev?.stopPropagation(); + if (hadRunInpainting()) { + setShowOriginal(() => { + window.setTimeout(() => { + setSliderPos(100); + }, 10); + return true; + }); + } + }, + (ev) => { + ev?.preventDefault(); + ev?.stopPropagation(); + if (hadRunInpainting()) { + window.setTimeout(() => { + setSliderPos(0); + }, 10); + window.setTimeout(() => { + setShowOriginal(false); + }, COMPARE_SLIDER_DURATION_MS); + } + }, + ); + + const download = useCallback(async () => { + if (file === undefined) { + return; + } + if (enableAutoSaving && renders.length > 0) { + try { + await downloadToOutput( + renders[renders.length - 1], + file.name, + file.type, + ); + toast({ + description: "Save image success", + }); + } catch (e: any) { + toast({ + variant: "destructive", + title: "Uh oh! Something went wrong.", + description: e.message ? e.message : e.toString(), + }); + } + return; + } + + // TODO: download to output directory + const name = file.name.replace(/(\.[\w\d_-]+)$/i, "_cleanup$1"); + const curRender = renders[renders.length - 1]; + downloadImage(curRender.currentSrc, name); + if (settings.enableDownloadMask) { + let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1"); + maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg"); + + const maskCanvas = generateMask( + imageWidth, + imageHeight, + lineGroups, + ); + // Create a link + const aDownloadLink = document.createElement("a"); + // Add the name of the file to the link + aDownloadLink.download = maskFileName; + // Attach the data to the link + aDownloadLink.href = maskCanvas.toDataURL("image/jpeg"); + // Get the code to click the download link + aDownloadLink.click(); + } + }, [ + file, + enableAutoSaving, + renders, + settings, + imageHeight, + imageWidth, + lineGroups, + ]); + + useHotKey("meta+s,ctrl+s", download); + + const toggleShowBrush = (newState: boolean) => { + if ( + newState !== showBrush && + !isPanning && + !isCropperExtenderResizing + ) { + setShowBrush(newState); + } + }; + + const getCursor = useCallback(() => { + if (isProcessing) { + return "default"; + } + if (isPanning) { + return "grab"; + } + if (showBrush) { + return "none"; + } + return undefined; + }, [showBrush, isPanning, isProcessing]); + + useHotKey( + "[", + () => { + decreaseBaseBrushSize(); + }, + [decreaseBaseBrushSize], + ); + + useHotKey( + "]", + () => { + increaseBaseBrushSize(); + }, + [increaseBaseBrushSize], + ); + + // Manual Inpainting Hotkey + useHotKey( + "shift+r", + () => { + if (runMannually && hadDrawSomething()) { + runInpainting(); + } + }, + [runMannually, runInpainting, hadDrawSomething], + ); + + useHotKey( + "ctrl+c,meta+c", + async () => { + const hasPermission = await askWritePermission(); + if (hasPermission && renders.length > 0) { + if (context?.canvas) { + await copyCanvasImage(context?.canvas); + toast({ + title: "Copy inpainting result to clipboard", + }); + } + } + }, + [renders, context], + ); + + // Toggle clean/zoom tool on spacebar. + useKeyPressEvent( + " ", + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault(); + ev?.stopPropagation(); + setShowBrush(false); + setIsPanning(true); + } + }, + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault(); + ev?.stopPropagation(); + setShowBrush(true); + setIsPanning(false); + } + }, + ); + + useKeyPressEvent( + "Alt", + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault(); + ev?.stopPropagation(); + setIsChangingBrushSizeByWheel(true); + } + }, + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault(); + ev?.stopPropagation(); + setIsChangingBrushSizeByWheel(false); + } + }, + ); + + const getCurScale = (): number => { + let s = minScale; + if (viewportRef.current?.instance?.transformState.scale !== undefined) { + s = viewportRef.current?.instance?.transformState.scale; + } + return s!; + }; + + const getBrushStyle = (_x: number, _y: number) => { + const curScale = getCurScale(); + return { + width: `${brushSize * curScale}px`, + height: `${brushSize * curScale}px`, + left: `${_x}px`, + top: `${_y}px`, + transform: "translate(-50%, -50%)", + }; + }; + + const renderBrush = (style: any) => { + return ( +
+ ); + }; + + const handleSliderChange = (value: number) => { + setBaseBrushSize(value); + + if (!showRefBrush) { + setShowRefBrush(true); + window.setTimeout(() => { + setShowRefBrush(false); + }, 10000); + } + }; + + const renderInteractiveSegCursor = () => { + return ( +
+ +
+ ); + }; + + const renderCanvas = () => { + return ( + { + if (r) { + viewportRef.current = r; + } + }} + panning={{ disabled: !isPanning, velocityDisabled: true }} + wheel={{ + step: 0.05, + wheelDisabled: isChangingBrushSizeByWheel, + }} + centerZoomedOut + alignmentAnimation={{ disabled: true }} + centerOnInit + limitToBounds={false} + doubleClick={{ disabled: true }} + initialScale={minScale} + minScale={minScale * 0.3} + onPanning={() => { + if (!panned) { + setPanned(true); + } + }} + onZoom={(ref) => { + setScale(ref.state.scale); + }} + > + +
+ { + if (r && !imageContext) { + const ctx = r.getContext("2d"); + if (ctx) { + setImageContext(ctx); + } + } + }} + /> + { + e.preventDefault(); + }} + onMouseOver={() => { + toggleShowBrush(true); + setShowRefBrush(false); + }} + onFocus={() => toggleShowBrush(true)} + onMouseLeave={() => toggleShowBrush(false)} + onMouseDown={onMouseDown} + onMouseUp={onCanvasMouseUp} + onMouseMove={onMouseDrag} + ref={(r) => { + if (r && !context) { + const ctx = r.getContext("2d"); + if (ctx) { + setContext(ctx); + } + } + }} + /> +
+ {showOriginal && ( + <> +
+ original + + )} +
+
+ + + + {interactiveSegState.isInteractiveSeg ? ( + + ) : ( + <> + )} + + + ); + }; + + const handleScroll = (event: React.WheelEvent) => { + // deltaY 是垂直滚动增量,正值表示向下滚动,负值表示向上滚动 + // deltaX 是水平滚动增量,正值表示向右滚动,负值表示向左滚动 + if (!isChangingBrushSizeByWheel) { + return; + } + + const { deltaY } = event; + // console.log(`水平滚动增量: ${deltaX}, 垂直滚动增量: ${deltaY}`) + if (deltaY > 0) { + increaseBaseBrushSize(); + } else if (deltaY < 0) { + decreaseBaseBrushSize(); + } + }; + + return ( + + ); } diff --git a/custom-demo/front-end/src/components/FileSelect.tsx b/custom-demo/front-end/src/components/FileSelect.tsx index c75da22..84f120e 100644 --- a/custom-demo/front-end/src/components/FileSelect.tsx +++ b/custom-demo/front-end/src/components/FileSelect.tsx @@ -1,71 +1,71 @@ -import { useState } from "react" -import useResolution from "@/hooks/useResolution" +import { useState } from "react"; +import useResolution from "@/hooks/useResolution"; type FileSelectProps = { - onSelection: (file: File) => void -} + onSelection: (file: File) => void; +}; export default function FileSelect(props: FileSelectProps) { - const { onSelection } = props + const { onSelection } = props; - const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`) + const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`); - const resolution = useResolution() + const resolution = useResolution(); - function onFileSelected(file: File) { - if (!file) { - return - } - // Skip non-image files - const isImage = file.type.match("image.*") - if (!isImage) { - return - } - try { - // Check if file is larger than 20mb - if (file.size > 20 * 1024 * 1024) { - throw new Error("file too large") - } - onSelection(file) - } catch (e) { - // eslint-disable-next-line - alert(`error: ${(e as any).message}`) - } - } + function onFileSelected(file: File) { + if (!file) { + return; + } + // Skip non-image files + const isImage = file.type.match("image.*"); + if (!isImage) { + return; + } + try { + // Check if file is larger than 20mb + if (file.size > 20 * 1024 * 1024) { + throw new Error("file too large"); + } + onSelection(file); + } catch (e) { + // eslint-disable-next-line + alert(`error: ${(e as any).message}`); + } + } - return ( -
- -
- ) + return ( +
+ +
+ ); } diff --git a/custom-demo/front-end/src/components/Header.tsx b/custom-demo/front-end/src/components/Header.tsx index 05acce6..da33f1e 100644 --- a/custom-demo/front-end/src/components/Header.tsx +++ b/custom-demo/front-end/src/components/Header.tsx @@ -1,198 +1,73 @@ -import { PlayIcon } from "@radix-ui/react-icons" -import { useState } from "react" -import { IconButton, ImageUploadButton } from "@/components/ui/button" -import Shortcuts from "@/components/Shortcuts" -import { useImage } from "@/hooks/useImage" +import { IconButton, ImageUploadButton } from "@/components/ui/button"; -import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover" -import PromptInput from "./PromptInput" -import { RotateCw, Image, Upload } from "lucide-react" -import FileManager from "./FileManager" -import { getMediaFile } from "@/lib/api" -import { useStore } from "@/lib/states" -import SettingsDialog from "./Settings" -import { cn, fileToImage } from "@/lib/utils" -import Coffee from "./Coffee" -import { useToast } from "./ui/use-toast" +import PromptInput from "./PromptInput"; +import { RotateCw, Image } from "lucide-react"; +import { useStore } from "@/lib/states"; const Header = () => { - const [ - file, - customMask, - isInpainting, - serverConfig, - runMannually, - enableUploadMask, - model, - setFile, - setCustomFile, - runInpainting, - showPrevMask, - hidePrevMask, - imageHeight, - imageWidth, - ] = useStore((state) => [ - state.file, - state.customMask, - state.isInpainting, - state.serverConfig, - state.runMannually(), - state.settings.enableUploadMask, - state.settings.model, - state.setFile, - state.setCustomFile, - state.runInpainting, - state.showPrevMask, - state.hidePrevMask, - state.imageHeight, - state.imageWidth, - ]) + const [ + file, + isInpainting, + model, + setFile, + runInpainting, + showPrevMask, + hidePrevMask, + ] = useStore((state) => [ + state.file, + state.isInpainting, + state.settings.model, + state.setFile, + state.runInpainting, + state.showPrevMask, + state.hidePrevMask, + ]); - const { toast } = useToast() - const [maskImage, maskImageLoaded] = useImage(customMask) - const [openMaskPopover, setOpenMaskPopover] = useState(false) + const handleRerunLastMask = () => { + runInpainting(); + }; - const handleRerunLastMask = () => { - runInpainting() - } + const onRerunMouseEnter = () => { + showPrevMask(); + }; - const onRerunMouseEnter = () => { - showPrevMask() - } + const onRerunMouseLeave = () => { + hidePrevMask(); + }; - const onRerunMouseLeave = () => { - hidePrevMask() - } + return ( +
+
+ { + setFile(file); + }} + > + + - return ( -
-
- {serverConfig.enableFileManager ? ( - { - try { - const newFile = await getMediaFile(tab, filename) - setFile(newFile) - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - return - } - }} - /> - ) : ( - <> - )} + {file && !model.need_prompt ? ( + + + + ) : ( + <> + )} +
- { - setFile(file) - }} - > - - + {model.need_prompt ? : <>} -
- { - let newCustomMask: HTMLImageElement | null = null - try { - newCustomMask = await fileToImage(file) - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - return - } - if ( - newCustomMask.naturalHeight !== imageHeight || - newCustomMask.naturalWidth !== imageWidth - ) { - toast({ - variant: "destructive", - description: `The size of the mask must same as image: ${imageWidth}x${imageHeight}`, - }) - return - } +
+
+ ); +}; - setCustomFile(file) - if (!runMannually) { - runInpainting() - } - }} - > - - - - {customMask ? ( - - setOpenMaskPopover(true)} - onMouseLeave={() => setOpenMaskPopover(false)} - style={{ - visibility: customMask ? "visible" : "hidden", - outline: "none", - }} - onClick={() => { - if (customMask) { - } - }} - > - - - - - - {maskImageLoaded ? ( - Custom mask - ) : ( - <> - )} - - - ) : ( - <> - )} -
- - {file && !model.need_prompt ? ( - - - - ) : ( - <> - )} -
- - {model.need_prompt ? : <>} - -
- - - {serverConfig.disableModelSwitch ? <> : } -
- - ) -} - -export default Header +export default Header; diff --git a/custom-demo/front-end/src/components/Workspace.tsx b/custom-demo/front-end/src/components/Workspace.tsx index cba0a15..fd10c97 100644 --- a/custom-demo/front-end/src/components/Workspace.tsx +++ b/custom-demo/front-end/src/components/Workspace.tsx @@ -1,39 +1,31 @@ -import { useEffect } from "react" -import Editor from "./Editor" -import { currentModel } from "@/lib/api" -import { useStore } from "@/lib/states" -import ImageSize from "./ImageSize" -import Plugins from "./Plugins" -import { InteractiveSeg } from "./InteractiveSeg" -import SidePanel from "./SidePanel" -import DiffusionProgress from "./DiffusionProgress" +import { useEffect } from "react"; +import Editor from "./Editor"; +import { currentModel } from "@/lib/api"; +import { useStore } from "@/lib/states"; +import ImageSize from "./ImageSize"; const Workspace = () => { - const [file, updateSettings] = useStore((state) => [ - state.file, - state.updateSettings, - ]) + const [file, updateSettings] = useStore((state) => [ + state.file, + state.updateSettings, + ]); - useEffect(() => { - const fetchCurrentModel = async () => { - const model = await currentModel() - updateSettings({ model }) - } - fetchCurrentModel() - }, []) + useEffect(() => { + const fetchCurrentModel = async () => { + const model = await currentModel(); + updateSettings({ model }); + }; + fetchCurrentModel(); + }, []); - return ( - <> -
- - -
- - - - {file ? : <>} - - ) -} + return ( + <> +
+ +
+ {file ? : <>} + + ); +}; -export default Workspace +export default Workspace; diff --git a/custom-demo/front-end/src/lib/api.ts b/custom-demo/front-end/src/lib/api.ts index deb5864..ef3026c 100644 --- a/custom-demo/front-end/src/lib/api.ts +++ b/custom-demo/front-end/src/lib/api.ts @@ -1,231 +1,256 @@ import { - Filename, - GenInfo, - ModelInfo, - PowerPaintTask, - Rect, - ServerConfig, + Filename, + GenInfo, + ModelInfo, + PowerPaintTask, + Rect, + ServerConfig, } from "@/lib/types"; import { Settings } from "@/lib/states"; import { convertToBase64, srcToFile } from "@/lib/utils"; import axios from "axios"; export const API_ENDPOINT = import.meta.env.DEV - ? import.meta.env.VITE_BACKEND + "/api/v1" - : "/api/v1"; + ? import.meta.env.VITE_BACKEND + "/api/v1" + : "/api/v1"; const api = axios.create({ - baseURL: API_ENDPOINT, + baseURL: API_ENDPOINT, }); export default async function inpaint( - imageFile: File, - settings: Settings, - croperRect: Rect, - extenderState: Rect, - mask: File | Blob, - paintByExampleImage: File | null = null + imageFile: File, + settings: Settings, + croperRect: Rect, + extenderState: Rect, + mask: File | Blob, + paintByExampleImage: File | null = null, ) { - const imageBase64 = await convertToBase64(imageFile); - const maskBase64 = await convertToBase64(mask); - const exampleImageBase64 = paintByExampleImage - ? await convertToBase64(paintByExampleImage) - : null; + const imageBase64 = await convertToBase64(imageFile); + const maskBase64 = await convertToBase64(mask); + const exampleImageBase64 = paintByExampleImage + ? await convertToBase64(paintByExampleImage) + : null; - const res = await fetch(`${API_ENDPOINT}/inpaint`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - image: imageBase64, - mask: maskBase64, - ldm_steps: settings.ldmSteps, - ldm_sampler: settings.ldmSampler, - zits_wireframe: settings.zitsWireframe, - cv2_flag: settings.cv2Flag, - cv2_radius: settings.cv2Radius, - hd_strategy: "Crop", - hd_strategy_crop_triger_size: 640, - hd_strategy_crop_margin: 128, - hd_trategy_resize_imit: 2048, - prompt: settings.prompt, - negative_prompt: settings.negativePrompt, - use_croper: settings.showCropper, - croper_x: croperRect.x, - croper_y: croperRect.y, - croper_height: croperRect.height, - croper_width: croperRect.width, - use_extender: settings.showExtender, - extender_x: extenderState.x, - extender_y: extenderState.y, - extender_height: extenderState.height, - extender_width: extenderState.width, - sd_mask_blur: settings.sdMaskBlur, - sd_strength: settings.sdStrength, - sd_steps: settings.sdSteps, - sd_guidance_scale: settings.sdGuidanceScale, - sd_sampler: settings.sdSampler, - sd_seed: settings.seedFixed ? settings.seed : -1, - sd_match_histograms: settings.sdMatchHistograms, - sd_freeu: settings.enableFreeu, - sd_freeu_config: settings.freeuConfig, - sd_lcm_lora: settings.enableLCMLora, - paint_by_example_example_image: exampleImageBase64, - p2p_image_guidance_scale: settings.p2pImageGuidanceScale, - enable_controlnet: settings.enableControlnet, - controlnet_conditioning_scale: settings.controlnetConditioningScale, - controlnet_method: settings.controlnetMethod - ? settings.controlnetMethod - : "", - powerpaint_task: settings.showExtender - ? PowerPaintTask.outpainting - : settings.powerpaintTask, - }), - }); - if (res.ok) { - const blob = await res.blob(); - return { - blob: URL.createObjectURL(blob), - seed: res.headers.get("X-Seed"), - }; - } - const errors = await res.json(); - throw new Error(`Something went wrong: ${errors.errors}`); + const res = await fetch(`${API_ENDPOINT}/inpaint`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + image: imageBase64, + mask: maskBase64, + ldm_steps: settings.ldmSteps, + ldm_sampler: settings.ldmSampler, + zits_wireframe: settings.zitsWireframe, + cv2_flag: settings.cv2Flag, + cv2_radius: settings.cv2Radius, + hd_strategy: "Crop", + hd_strategy_crop_triger_size: 640, + hd_strategy_crop_margin: 128, + hd_trategy_resize_imit: 2048, + prompt: settings.prompt, + negative_prompt: settings.negativePrompt, + use_croper: settings.showCropper, + croper_x: croperRect.x, + croper_y: croperRect.y, + croper_height: croperRect.height, + croper_width: croperRect.width, + use_extender: settings.showExtender, + extender_x: extenderState.x, + extender_y: extenderState.y, + extender_height: extenderState.height, + extender_width: extenderState.width, + sd_mask_blur: settings.sdMaskBlur, + sd_strength: settings.sdStrength, + sd_steps: settings.sdSteps, + sd_guidance_scale: settings.sdGuidanceScale, + sd_sampler: settings.sdSampler, + sd_seed: settings.seedFixed ? settings.seed : -1, + sd_match_histograms: settings.sdMatchHistograms, + sd_freeu: settings.enableFreeu, + sd_freeu_config: settings.freeuConfig, + sd_lcm_lora: settings.enableLCMLora, + paint_by_example_example_image: exampleImageBase64, + p2p_image_guidance_scale: settings.p2pImageGuidanceScale, + enable_controlnet: settings.enableControlnet, + controlnet_conditioning_scale: settings.controlnetConditioningScale, + controlnet_method: settings.controlnetMethod + ? settings.controlnetMethod + : "", + powerpaint_task: settings.showExtender + ? PowerPaintTask.outpainting + : settings.powerpaintTask, + }), + }); + if (res.ok) { + const blob = await res.blob(); + return { + blob: URL.createObjectURL(blob), + seed: res.headers.get("X-Seed"), + }; + } + const errors = await res.json(); + throw new Error(`Something went wrong: ${errors.errors}`); } export async function getServerConfig(): Promise { - const res = await api.get(`/server-config`); - return res.data; + const res = await api.get(`/server-config`); + return res.data; } export async function switchModel(name: string): Promise { - const res = await api.post(`/model`, { name }); - return res.data; + const res = await api.post(`/model`, { name }); + return res.data; } export async function switchPluginModel( - plugin_name: string, - model_name: string + plugin_name: string, + model_name: string, ) { - return api.post(`/switch_plugin_model`, { plugin_name, model_name }); + return api.post(`/switch_plugin_model`, { plugin_name, model_name }); } export async function currentModel(): Promise { - const res = await api.get("/model"); - return res.data; + const res = await api.get("/model"); + return res.data; } export async function runPlugin( - genMask: boolean, - name: string, - imageFile: File, - upscale?: number, - clicks?: number[][] + genMask: boolean, + name: string, + imageFile: File, + upscale?: number, + clicks?: number[][], ) { - const imageBase64 = await convertToBase64(imageFile); - const p = genMask ? "run_plugin_gen_mask" : "run_plugin_gen_image"; - const res = await fetch(`${API_ENDPOINT}/${p}`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - name, - image: imageBase64, - upscale, - clicks, - }), - }); - if (res.ok) { - const blob = await res.blob(); - return { blob: URL.createObjectURL(blob) }; - } - const errMsg = await res.json(); - throw new Error(errMsg); + const imageBase64 = await convertToBase64(imageFile); + const p = genMask ? "run_plugin_gen_mask" : "run_plugin_gen_image"; + const res = await fetch(`${API_ENDPOINT}/${p}`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + name, + image: imageBase64, + upscale, + clicks, + }), + }); + if (res.ok) { + const blob = await res.blob(); + return { blob: URL.createObjectURL(blob) }; + } + const errMsg = await res.json(); + throw new Error(errMsg); } export async function getMediaFile(tab: string, filename: string) { - const res = await fetch( - `${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent( - filename - )}`, - { - method: "GET", - } - ); - if (res.ok) { - const blob = await res.blob(); - const file = new File([blob], filename, { - type: res.headers.get("Content-Type") ?? "image/png", - }); - return file; - } - const errMsg = await res.json(); - throw new Error(errMsg.errors); + const res = await fetch( + `${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent( + filename, + )}`, + { + method: "GET", + }, + ); + if (res.ok) { + const blob = await res.blob(); + const file = new File([blob], filename, { + type: res.headers.get("Content-Type") ?? "image/png", + }); + return file; + } + const errMsg = await res.json(); + throw new Error(errMsg.errors); } export async function getMedias(tab: string): Promise { - const res = await api.get(`medias`, { params: { tab } }); - return res.data; + const res = await api.get(`medias`, { params: { tab } }); + return res.data; } export async function downloadToOutput( - image: HTMLImageElement, - filename: string, - mimeType: string + image: HTMLImageElement, + filename: string, + mimeType: string, ) { - const file = await srcToFile(image.src, filename, mimeType); - const fd = new FormData(); - fd.append("file", file); + const file = await srcToFile(image.src, filename, mimeType); + const fd = new FormData(); + fd.append("file", file); - try { - const res = await fetch(`${API_ENDPOINT}/save_image`, { - method: "POST", - body: fd, - }); - if (!res.ok) { - const errMsg = await res.text(); - throw new Error(errMsg); - } - } catch (error) { - throw new Error(`Something went wrong: ${error}`); - } + try { + const res = await fetch(`${API_ENDPOINT}/save_image`, { + method: "POST", + body: fd, + }); + if (!res.ok) { + const errMsg = await res.text(); + throw new Error(errMsg); + } + } catch (error) { + throw new Error(`Something went wrong: ${error}`); + } } export async function getGenInfo(file: File): Promise { - const fd = new FormData(); - fd.append("file", file); - const res = await api.post(`/gen-info`, fd); - return res.data; + const fd = new FormData(); + fd.append("file", file); + const res = await api.post(`/gen-info`, fd); + return res.data; } export async function getSamplers(): Promise { - const res = await api.post("/samplers"); - return res.data; + const res = await api.post("/samplers"); + return res.data; } export async function postAdjustMask( - mask: File | Blob, - operate: "expand" | "shrink" | "reverse", - kernel_size: number + mask: File | Blob, + operate: "expand" | "shrink" | "reverse", + kernel_size: number, ) { - const maskBase64 = await convertToBase64(mask); - const res = await fetch(`${API_ENDPOINT}/adjust_mask`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - mask: maskBase64, - operate: operate, - kernel_size: kernel_size, - }), - }); - if (res.ok) { - const blob = await res.blob(); - return blob; - } - const errMsg = await res.json(); - throw new Error(errMsg); + const maskBase64 = await convertToBase64(mask); + const res = await fetch(`${API_ENDPOINT}/adjust_mask`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + mask: maskBase64, + operate: operate, + kernel_size: kernel_size, + }), + }); + if (res.ok) { + const blob = await res.blob(); + return blob; + } + const errMsg = await res.json(); + throw new Error(errMsg); +} + +export async function submitMask(imageFile: File, mask: File | Blob) { + const imageBase64 = await convertToBase64(imageFile); + const maskBase64 = await convertToBase64(mask); + + const res = await fetch(`${API_ENDPOINT}/submit-mask`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + image: imageBase64, + mask: maskBase64, + }), + }); + if (res.ok) { + const blob = await res.blob(); + return { + blob: URL.createObjectURL(blob), + seed: res.headers.get("X-Seed"), + }; + } + // const errors = await res.json(); + throw new Error(`Submit successfull.`); } diff --git a/custom-demo/front-end/src/lib/states.ts b/custom-demo/front-end/src/lib/states.ts index faba1ba..cac52b8 100644 --- a/custom-demo/front-end/src/lib/states.ts +++ b/custom-demo/front-end/src/lib/states.ts @@ -1,1089 +1,1250 @@ -import { persist } from "zustand/middleware" -import { shallow } from "zustand/shallow" -import { immer } from "zustand/middleware/immer" -import { castDraft } from "immer" -import { createWithEqualityFn } from "zustand/traditional" +import { persist } from "zustand/middleware"; +import { shallow } from "zustand/shallow"; +import { immer } from "zustand/middleware/immer"; +import { castDraft } from "immer"; +import { createWithEqualityFn } from "zustand/traditional"; import { - AdjustMaskOperate, - CV2Flag, - ExtenderDirection, - FreeuConfig, - LDMSampler, - Line, - LineGroup, - ModelInfo, - PluginParams, - Point, - PowerPaintTask, - ServerConfig, - Size, - SortBy, - SortOrder, -} from "./types" + AdjustMaskOperate, + CV2Flag, + ExtenderDirection, + FreeuConfig, + LDMSampler, + Line, + LineGroup, + ModelInfo, + PluginParams, + Point, + PowerPaintTask, + ServerConfig, + Size, + SortBy, + SortOrder, +} from "./types"; import { - BRUSH_COLOR, - DEFAULT_BRUSH_SIZE, - DEFAULT_NEGATIVE_PROMPT, - MAX_BRUSH_SIZE, - MODEL_TYPE_INPAINT, - PAINT_BY_EXAMPLE, -} from "./const" + BRUSH_COLOR, + DEFAULT_BRUSH_SIZE, + DEFAULT_NEGATIVE_PROMPT, + MAX_BRUSH_SIZE, + MODEL_TYPE_INPAINT, + PAINT_BY_EXAMPLE, +} from "./const"; import { - blobToImage, - canvasToImage, - dataURItoBlob, - generateMask, - loadImage, - srcToFile, -} from "./utils" -import inpaint, { getGenInfo, postAdjustMask, runPlugin } from "./api" -import { toast } from "@/components/ui/use-toast" + blobToImage, + canvasToImage, + dataURItoBlob, + generateMask, + loadImage, + srcToFile, +} from "./utils"; +import inpaint, { + getGenInfo, + postAdjustMask, + runPlugin, + submitMask, +} from "./api"; +import { toast } from "@/components/ui/use-toast"; type FileManagerState = { - sortBy: SortBy - sortOrder: SortOrder - layout: "rows" | "masonry" - searchText: string - inputDirectory: string - outputDirectory: string -} + sortBy: SortBy; + sortOrder: SortOrder; + layout: "rows" | "masonry"; + searchText: string; + inputDirectory: string; + outputDirectory: string; +}; type CropperState = { - x: number - y: number - width: number - height: number -} + x: number; + y: number; + width: number; + height: number; +}; export type Settings = { - model: ModelInfo - enableDownloadMask: boolean - enableManualInpainting: boolean - enableUploadMask: boolean - enableAutoExtractPrompt: boolean - showCropper: boolean - showExtender: boolean - extenderDirection: ExtenderDirection + model: ModelInfo; + enableDownloadMask: boolean; + enableManualInpainting: boolean; + enableUploadMask: boolean; + enableAutoExtractPrompt: boolean; + showCropper: boolean; + showExtender: boolean; + extenderDirection: ExtenderDirection; - // For LDM - ldmSteps: number - ldmSampler: LDMSampler + // For LDM + ldmSteps: number; + ldmSampler: LDMSampler; - // For ZITS - zitsWireframe: boolean + // For ZITS + zitsWireframe: boolean; - // For OpenCV2 - cv2Radius: number - cv2Flag: CV2Flag + // For OpenCV2 + cv2Radius: number; + cv2Flag: CV2Flag; - // For Diffusion moel - prompt: string - negativePrompt: string - seed: number - seedFixed: boolean + // For Diffusion moel + prompt: string; + negativePrompt: string; + seed: number; + seedFixed: boolean; - // For SD - sdMaskBlur: number - sdStrength: number - sdSteps: number - sdGuidanceScale: number - sdSampler: string - sdMatchHistograms: boolean - sdScale: number + // For SD + sdMaskBlur: number; + sdStrength: number; + sdSteps: number; + sdGuidanceScale: number; + sdSampler: string; + sdMatchHistograms: boolean; + sdScale: number; - // Pix2Pix - p2pImageGuidanceScale: number + // Pix2Pix + p2pImageGuidanceScale: number; - // ControlNet - enableControlnet: boolean - controlnetConditioningScale: number - controlnetMethod: string + // ControlNet + enableControlnet: boolean; + controlnetConditioningScale: number; + controlnetMethod: string; - enableLCMLora: boolean - enableFreeu: boolean - freeuConfig: FreeuConfig + enableLCMLora: boolean; + enableFreeu: boolean; + freeuConfig: FreeuConfig; - // PowerPaint - powerpaintTask: PowerPaintTask + // PowerPaint + powerpaintTask: PowerPaintTask; - // AdjustMask - adjustMaskKernelSize: number -} + // AdjustMask + adjustMaskKernelSize: number; +}; type InteractiveSegState = { - isInteractiveSeg: boolean - tmpInteractiveSegMask: HTMLImageElement | null - clicks: number[][] -} + isInteractiveSeg: boolean; + tmpInteractiveSegMask: HTMLImageElement | null; + clicks: number[][]; +}; type EditorState = { - baseBrushSize: number - brushSizeScale: number - renders: HTMLImageElement[] - lineGroups: LineGroup[] - lastLineGroup: LineGroup - curLineGroup: LineGroup + baseBrushSize: number; + brushSizeScale: number; + renders: HTMLImageElement[]; + lineGroups: LineGroup[]; + lastLineGroup: LineGroup; + curLineGroup: LineGroup; - // mask from interactive-seg or other segmentation models - extraMasks: HTMLImageElement[] - prevExtraMasks: HTMLImageElement[] + // mask from interactive-seg or other segmentation models + extraMasks: HTMLImageElement[]; + prevExtraMasks: HTMLImageElement[]; - temporaryMasks: HTMLImageElement[] - // redo 相关 - redoRenders: HTMLImageElement[] - redoCurLines: Line[] - redoLineGroups: LineGroup[] -} + temporaryMasks: HTMLImageElement[]; + // redo 相关 + redoRenders: HTMLImageElement[]; + redoCurLines: Line[]; + redoLineGroups: LineGroup[]; +}; type AppState = { - file: File | null - paintByExampleFile: File | null - customMask: File | null - imageHeight: number - imageWidth: number - isInpainting: boolean - isPluginRunning: boolean - isAdjustingMask: boolean - windowSize: Size - editorState: EditorState - disableShortCuts: boolean + file: File | null; + paintByExampleFile: File | null; + customMask: File | null; + imageHeight: number; + imageWidth: number; + isInpainting: boolean; + isPluginRunning: boolean; + isAdjustingMask: boolean; + windowSize: Size; + editorState: EditorState; + disableShortCuts: boolean; - interactiveSegState: InteractiveSegState - fileManagerState: FileManagerState + interactiveSegState: InteractiveSegState; + fileManagerState: FileManagerState; - cropperState: CropperState - extenderState: CropperState - isCropperExtenderResizing: boolean + cropperState: CropperState; + extenderState: CropperState; + isCropperExtenderResizing: boolean; - serverConfig: ServerConfig + serverConfig: ServerConfig; - settings: Settings -} + settings: Settings; +}; type AppAction = { - updateAppState: (newState: Partial) => void - setFile: (file: File) => Promise - setCustomFile: (file: File) => void - setIsInpainting: (newValue: boolean) => void - getIsProcessing: () => boolean - setBaseBrushSize: (newValue: number) => void - decreaseBaseBrushSize: () => void - increaseBaseBrushSize: () => void - getBrushSize: () => number - setImageSize: (width: number, height: number) => void + updateAppState: (newState: Partial) => void; + setFile: (file: File) => Promise; + setCustomFile: (file: File) => void; + setIsInpainting: (newValue: boolean) => void; + getIsProcessing: () => boolean; + setBaseBrushSize: (newValue: number) => void; + decreaseBaseBrushSize: () => void; + increaseBaseBrushSize: () => void; + getBrushSize: () => number; + setImageSize: (width: number, height: number) => void; - isSD: () => boolean + isSD: () => boolean; - setCropperX: (newValue: number) => void - setCropperY: (newValue: number) => void - setCropperWidth: (newValue: number) => void - setCropperHeight: (newValue: number) => void + setCropperX: (newValue: number) => void; + setCropperY: (newValue: number) => void; + setCropperWidth: (newValue: number) => void; + setCropperHeight: (newValue: number) => void; - setExtenderX: (newValue: number) => void - setExtenderY: (newValue: number) => void - setExtenderWidth: (newValue: number) => void - setExtenderHeight: (newValue: number) => void + setExtenderX: (newValue: number) => void; + setExtenderY: (newValue: number) => void; + setExtenderWidth: (newValue: number) => void; + setExtenderHeight: (newValue: number) => void; - setIsCropperExtenderResizing: (newValue: boolean) => void - updateExtenderDirection: (newValue: ExtenderDirection) => void - resetExtender: (width: number, height: number) => void - updateExtenderByBuiltIn: (direction: ExtenderDirection, scale: number) => void + setIsCropperExtenderResizing: (newValue: boolean) => void; + updateExtenderDirection: (newValue: ExtenderDirection) => void; + resetExtender: (width: number, height: number) => void; + updateExtenderByBuiltIn: ( + direction: ExtenderDirection, + scale: number, + ) => void; - setServerConfig: (newValue: ServerConfig) => void - setSeed: (newValue: number) => void - updateSettings: (newSettings: Partial) => void - setModel: (newModel: ModelInfo) => void - updateFileManagerState: (newState: Partial) => void - updateInteractiveSegState: (newState: Partial) => void - resetInteractiveSegState: () => void - handleInteractiveSegAccept: () => void - showPromptInput: () => boolean + setServerConfig: (newValue: ServerConfig) => void; + setSeed: (newValue: number) => void; + updateSettings: (newSettings: Partial) => void; + setModel: (newModel: ModelInfo) => void; + updateFileManagerState: (newState: Partial) => void; + updateInteractiveSegState: (newState: Partial) => void; + resetInteractiveSegState: () => void; + handleInteractiveSegAccept: () => void; + showPromptInput: () => boolean; - runInpainting: () => Promise - showPrevMask: () => Promise - hidePrevMask: () => void - runRenderablePlugin: ( - genMask: boolean, - pluginName: string, - params?: PluginParams - ) => Promise + submitMaskImage: () => Promise; + runInpainting: () => Promise; + showPrevMask: () => Promise; + hidePrevMask: () => void; + runRenderablePlugin: ( + genMask: boolean, + pluginName: string, + params?: PluginParams, + ) => Promise; - // EditorState - getCurrentTargetFile: () => Promise - updateEditorState: (newState: Partial) => void - runMannually: () => boolean - handleCanvasMouseDown: (point: Point) => void - handleCanvasMouseMove: (point: Point) => void - cleanCurLineGroup: () => void - resetRedoState: () => void - undo: () => void - redo: () => void - undoDisabled: () => boolean - redoDisabled: () => boolean + // EditorState + getCurrentTargetFile: () => Promise; + updateEditorState: (newState: Partial) => void; + runMannually: () => boolean; + handleCanvasMouseDown: (point: Point) => void; + handleCanvasMouseMove: (point: Point) => void; + cleanCurLineGroup: () => void; + resetRedoState: () => void; + undo: () => void; + redo: () => void; + undoDisabled: () => boolean; + redoDisabled: () => boolean; - adjustMask: (operate: AdjustMaskOperate) => Promise - clearMask: () => void -} + adjustMask: (operate: AdjustMaskOperate) => Promise; + clearMask: () => void; +}; const defaultValues: AppState = { - file: null, - paintByExampleFile: null, - customMask: null, - imageHeight: 0, - imageWidth: 0, - isInpainting: false, - isPluginRunning: false, - isAdjustingMask: false, - disableShortCuts: false, + file: null, + paintByExampleFile: null, + customMask: null, + imageHeight: 0, + imageWidth: 0, + isInpainting: false, + isPluginRunning: false, + isAdjustingMask: false, + disableShortCuts: false, - windowSize: { - height: 600, - width: 800, - }, - editorState: { - baseBrushSize: DEFAULT_BRUSH_SIZE, - brushSizeScale: 1, - renders: [], - extraMasks: [], - prevExtraMasks: [], - temporaryMasks: [], - lineGroups: [], - lastLineGroup: [], - curLineGroup: [], - redoRenders: [], - redoCurLines: [], - redoLineGroups: [], - }, + windowSize: { + height: 600, + width: 800, + }, + editorState: { + baseBrushSize: DEFAULT_BRUSH_SIZE, + brushSizeScale: 1, + renders: [], + extraMasks: [], + prevExtraMasks: [], + temporaryMasks: [], + lineGroups: [], + lastLineGroup: [], + curLineGroup: [], + redoRenders: [], + redoCurLines: [], + redoLineGroups: [], + }, - interactiveSegState: { - isInteractiveSeg: false, - tmpInteractiveSegMask: null, - clicks: [], - }, + interactiveSegState: { + isInteractiveSeg: false, + tmpInteractiveSegMask: null, + clicks: [], + }, - cropperState: { - x: 0, - y: 0, - width: 512, - height: 512, - }, - extenderState: { - x: 0, - y: 0, - width: 512, - height: 512, - }, - isCropperExtenderResizing: false, + cropperState: { + x: 0, + y: 0, + width: 512, + height: 512, + }, + extenderState: { + x: 0, + y: 0, + width: 512, + height: 512, + }, + isCropperExtenderResizing: false, - fileManagerState: { - sortBy: SortBy.CTIME, - sortOrder: SortOrder.DESCENDING, - layout: "masonry", - searchText: "", - inputDirectory: "", - outputDirectory: "", - }, - serverConfig: { - plugins: [], - modelInfos: [], - removeBGModel: "briaai/RMBG-1.4", - removeBGModels: [], - realesrganModel: "realesr-general-x4v3", - realesrganModels: [], - interactiveSegModel: "vit_b", - interactiveSegModels: [], - enableFileManager: false, - enableAutoSaving: false, - enableControlnet: false, - controlnetMethod: "lllyasviel/control_v11p_sd15_canny", - disableModelSwitch: false, - isDesktop: false, - samplers: ["DPM++ 2M SDE Karras"], - }, - settings: { - model: { - name: "lama", - path: "lama", - model_type: "inpaint", - support_controlnet: false, - support_strength: false, - support_outpainting: false, - controlnets: [], - support_freeu: false, - support_lcm_lora: false, - is_single_file_diffusers: false, - need_prompt: false, - }, - enableControlnet: false, - showCropper: false, - showExtender: false, - extenderDirection: ExtenderDirection.xy, - enableDownloadMask: false, - enableManualInpainting: false, - enableUploadMask: false, - enableAutoExtractPrompt: true, - ldmSteps: 30, - ldmSampler: LDMSampler.ddim, - zitsWireframe: true, - cv2Radius: 5, - cv2Flag: CV2Flag.INPAINT_NS, - prompt: "", - negativePrompt: DEFAULT_NEGATIVE_PROMPT, - seed: 42, - seedFixed: false, - sdMaskBlur: 12, - sdStrength: 1.0, - sdSteps: 50, - sdGuidanceScale: 7.5, - sdSampler: "DPM++ 2M", - sdMatchHistograms: false, - sdScale: 1.0, - p2pImageGuidanceScale: 1.5, - controlnetConditioningScale: 0.4, - controlnetMethod: "lllyasviel/control_v11p_sd15_canny", - enableLCMLora: false, - enableFreeu: false, - freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 }, - powerpaintTask: PowerPaintTask.text_guided, - adjustMaskKernelSize: 12, - }, -} + fileManagerState: { + sortBy: SortBy.CTIME, + sortOrder: SortOrder.DESCENDING, + layout: "masonry", + searchText: "", + inputDirectory: "", + outputDirectory: "", + }, + serverConfig: { + plugins: [], + modelInfos: [], + removeBGModel: "briaai/RMBG-1.4", + removeBGModels: [], + realesrganModel: "realesr-general-x4v3", + realesrganModels: [], + interactiveSegModel: "vit_b", + interactiveSegModels: [], + enableFileManager: false, + enableAutoSaving: false, + enableControlnet: false, + controlnetMethod: "lllyasviel/control_v11p_sd15_canny", + disableModelSwitch: false, + isDesktop: false, + samplers: ["DPM++ 2M SDE Karras"], + }, + settings: { + model: { + name: "lama", + path: "lama", + model_type: "inpaint", + support_controlnet: false, + support_strength: false, + support_outpainting: false, + controlnets: [], + support_freeu: false, + support_lcm_lora: false, + is_single_file_diffusers: false, + need_prompt: false, + }, + enableControlnet: false, + showCropper: false, + showExtender: false, + extenderDirection: ExtenderDirection.xy, + enableDownloadMask: false, + enableManualInpainting: false, + enableUploadMask: false, + enableAutoExtractPrompt: true, + ldmSteps: 30, + ldmSampler: LDMSampler.ddim, + zitsWireframe: true, + cv2Radius: 5, + cv2Flag: CV2Flag.INPAINT_NS, + prompt: "", + negativePrompt: DEFAULT_NEGATIVE_PROMPT, + seed: 42, + seedFixed: false, + sdMaskBlur: 12, + sdStrength: 1.0, + sdSteps: 50, + sdGuidanceScale: 7.5, + sdSampler: "DPM++ 2M", + sdMatchHistograms: false, + sdScale: 1.0, + p2pImageGuidanceScale: 1.5, + controlnetConditioningScale: 0.4, + controlnetMethod: "lllyasviel/control_v11p_sd15_canny", + enableLCMLora: false, + enableFreeu: false, + freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 }, + powerpaintTask: PowerPaintTask.text_guided, + adjustMaskKernelSize: 12, + }, +}; export const useStore = createWithEqualityFn()( - persist( - immer((set, get) => ({ - ...defaultValues, - - showPrevMask: async () => { - if (get().settings.showExtender) { - return - } - const { lastLineGroup, curLineGroup, prevExtraMasks, extraMasks } = - get().editorState - if (curLineGroup.length !== 0 || extraMasks.length !== 0) { - return - } - const { imageWidth, imageHeight } = get() - - const maskCanvas = generateMask( - imageWidth, - imageHeight, - [lastLineGroup], - prevExtraMasks, - BRUSH_COLOR - ) - try { - const maskImage = await canvasToImage(maskCanvas) - set((state) => { - state.editorState.temporaryMasks.push(castDraft(maskImage)) - }) - } catch (e) { - console.error(e) - return - } - }, - hidePrevMask: () => { - set((state) => { - state.editorState.temporaryMasks = [] - }) - }, - - getCurrentTargetFile: async (): Promise => { - const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数 - const renders = get().editorState.renders - - let targetFile = file - if (renders.length > 0) { - const lastRender = renders[renders.length - 1] - targetFile = await srcToFile( - lastRender.currentSrc, - file.name, - file.type - ) - } - return targetFile - }, - - runInpainting: async () => { - const { - isInpainting, - file, - paintByExampleFile, - imageWidth, - imageHeight, - settings, - cropperState, - extenderState, - } = get() - if (isInpainting || file === null) { - return - } - if ( - get().settings.model.support_outpainting && - settings.showExtender && - extenderState.height === imageHeight && - extenderState.width === imageWidth - ) { - return - } - - const { - lastLineGroup, - curLineGroup, - lineGroups, - renders, - prevExtraMasks, - extraMasks, - } = get().editorState - - const useLastLineGroup = - curLineGroup.length === 0 && - extraMasks.length === 0 && - !settings.showExtender - - // useLastLineGroup 的影响 - // 1. 使用上一次的 mask - // 2. 结果替换当前 render - let maskImages: HTMLImageElement[] = [] - let maskLineGroup: LineGroup = [] - if (useLastLineGroup === true) { - maskLineGroup = lastLineGroup - maskImages = prevExtraMasks - } else { - maskLineGroup = curLineGroup - maskImages = extraMasks - } - - if ( - maskLineGroup.length === 0 && - maskImages === null && - !settings.showExtender - ) { - toast({ - variant: "destructive", - description: "Please draw mask on picture", - }) - return - } - - const newLineGroups = [...lineGroups, maskLineGroup] - - set((state) => { - state.isInpainting = true - }) - - let targetFile = file - if (useLastLineGroup === true) { - // renders.length == 1 还是用原来的 - if (renders.length > 1) { - const lastRender = renders[renders.length - 2] - targetFile = await srcToFile( - lastRender.currentSrc, - file.name, - file.type - ) - } - } else if (renders.length > 0) { - const lastRender = renders[renders.length - 1] - targetFile = await srcToFile( - lastRender.currentSrc, - file.name, - file.type - ) - } - - const maskCanvas = generateMask( - imageWidth, - imageHeight, - [maskLineGroup], - maskImages, - BRUSH_COLOR - ) - if (useLastLineGroup) { - const temporaryMask = await canvasToImage(maskCanvas) - set((state) => { - state.editorState.temporaryMasks = castDraft([temporaryMask]) - }) - } - - try { - const res = await inpaint( - targetFile, - settings, - cropperState, - extenderState, - dataURItoBlob(maskCanvas.toDataURL()), - paintByExampleFile - ) - - const { blob, seed } = res - if (seed) { - get().setSeed(parseInt(seed, 10)) - } - const newRender = new Image() - await loadImage(newRender, blob) - const newRenders = [...renders, newRender] - get().setImageSize(newRender.width, newRender.height) - get().updateEditorState({ - renders: newRenders, - lineGroups: newLineGroups, - lastLineGroup: maskLineGroup, - curLineGroup: [], - extraMasks: [], - prevExtraMasks: maskImages, - }) - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - } - - get().resetRedoState() - set((state) => { - state.isInpainting = false - state.editorState.temporaryMasks = [] - }) - }, - - runRenderablePlugin: async ( - genMask: boolean, - pluginName: string, - params: PluginParams = { upscale: 1 } - ) => { - const { renders, lineGroups } = get().editorState - set((state) => { - state.isPluginRunning = true - }) - - try { - const start = new Date() - const targetFile = await get().getCurrentTargetFile() - const res = await runPlugin( - genMask, - pluginName, - targetFile, - params.upscale - ) - const { blob } = res - - if (!genMask) { - const newRender = new Image() - await loadImage(newRender, blob) - get().setImageSize(newRender.width, newRender.height) - const newRenders = [...renders, newRender] - const newLineGroups = [...lineGroups, []] - get().updateEditorState({ - renders: newRenders, - lineGroups: newLineGroups, - }) - } else { - const newMask = new Image() - await loadImage(newMask, blob) - set((state) => { - state.editorState.extraMasks.push(castDraft(newMask)) - }) - } - const end = new Date() - const time = end.getTime() - start.getTime() - toast({ - description: `Run ${pluginName} successfully in ${time / 1000}s`, - }) - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - } - set((state) => { - state.isPluginRunning = false - }) - }, - - // Edirot State // - updateEditorState: (newState: Partial) => { - set((state) => { - state.editorState = castDraft({ ...state.editorState, ...newState }) - }) - }, - - cleanCurLineGroup: () => { - get().updateEditorState({ curLineGroup: [] }) - }, - - handleCanvasMouseDown: (point: Point) => { - let lineGroup: LineGroup = [] - const state = get() - if (state.runMannually()) { - lineGroup = [...state.editorState.curLineGroup] - } - lineGroup.push({ size: state.getBrushSize(), pts: [point] }) - set((state) => { - state.editorState.curLineGroup = lineGroup - }) - }, - - handleCanvasMouseMove: (point: Point) => { - set((state) => { - const curLineGroup = state.editorState.curLineGroup - if (curLineGroup.length) { - curLineGroup[curLineGroup.length - 1].pts.push(point) - } - }) - }, - - runMannually: (): boolean => { - const state = get() - return ( - state.settings.enableManualInpainting || - state.settings.model.model_type !== MODEL_TYPE_INPAINT - ) - }, - - getIsProcessing: (): boolean => { - return ( - get().isInpainting || get().isPluginRunning || get().isAdjustingMask - ) - }, - - isSD: (): boolean => { - return get().settings.model.model_type !== MODEL_TYPE_INPAINT - }, - - // undo/redo - - undoDisabled: (): boolean => { - const editorState = get().editorState - if (editorState.renders.length > 0) { - return false - } - if (get().runMannually()) { - if (editorState.curLineGroup.length === 0) { - return true - } - } else if (editorState.renders.length === 0) { - return true - } - return false - }, - - undo: () => { - if ( - get().runMannually() && - get().editorState.curLineGroup.length !== 0 - ) { - // undoStroke - set((state) => { - const editorState = state.editorState - if (editorState.curLineGroup.length === 0) { - return - } - editorState.lastLineGroup = [] - const lastLine = editorState.curLineGroup.pop()! - editorState.redoCurLines.push(lastLine) - }) - } else { - set((state) => { - const editorState = state.editorState - if ( - editorState.renders.length === 0 || - editorState.lineGroups.length === 0 - ) { - return - } - const lastLineGroup = editorState.lineGroups.pop()! - editorState.redoLineGroups.push(lastLineGroup) - editorState.redoCurLines = [] - editorState.curLineGroup = [] - - const lastRender = editorState.renders.pop()! - editorState.redoRenders.push(lastRender) - }) - } - }, - - redoDisabled: (): boolean => { - const editorState = get().editorState - if (editorState.redoRenders.length > 0) { - return false - } - if (get().runMannually()) { - if (editorState.redoCurLines.length === 0) { - return true - } - } else if (editorState.redoRenders.length === 0) { - return true - } - return false - }, - - redo: () => { - if ( - get().runMannually() && - get().editorState.redoCurLines.length !== 0 - ) { - set((state) => { - const editorState = state.editorState - if (editorState.redoCurLines.length === 0) { - return - } - const line = editorState.redoCurLines.pop()! - editorState.curLineGroup.push(line) - }) - } else { - set((state) => { - const editorState = state.editorState - if ( - editorState.redoRenders.length === 0 || - editorState.redoLineGroups.length === 0 - ) { - return - } - const lastLineGroup = editorState.redoLineGroups.pop()! - editorState.lineGroups.push(lastLineGroup) - editorState.curLineGroup = [] - - const lastRender = editorState.redoRenders.pop()! - editorState.renders.push(lastRender) - }) - } - }, - - resetRedoState: () => { - set((state) => { - state.editorState.redoCurLines = [] - state.editorState.redoLineGroups = [] - state.editorState.redoRenders = [] - }) - }, - - //****// - - updateAppState: (newState: Partial) => { - set(() => newState) - }, - - getBrushSize: (): number => { - return ( - get().editorState.baseBrushSize * get().editorState.brushSizeScale - ) - }, - - showPromptInput: (): boolean => { - const model = get().settings.model - return ( - model.model_type !== MODEL_TYPE_INPAINT && - model.name !== PAINT_BY_EXAMPLE - ) - }, - - setServerConfig: (newValue: ServerConfig) => { - set((state) => { - state.serverConfig = newValue - state.settings.enableControlnet = newValue.enableControlnet - state.settings.controlnetMethod = newValue.controlnetMethod - }) - }, - - updateSettings: (newSettings: Partial) => { - set((state) => { - state.settings = { - ...state.settings, - ...newSettings, - } - }) - }, - - setModel: (newModel: ModelInfo) => { - set((state) => { - state.settings.model = newModel - - if ( - newModel.support_controlnet && - !newModel.controlnets.includes(state.settings.controlnetMethod) - ) { - state.settings.controlnetMethod = newModel.controlnets[0] - } - }) - }, - - updateFileManagerState: (newState: Partial) => { - set((state) => { - state.fileManagerState = { - ...state.fileManagerState, - ...newState, - } - }) - }, - - updateInteractiveSegState: (newState: Partial) => { - set((state) => { - return { - ...state, - interactiveSegState: { - ...state.interactiveSegState, - ...newState, - }, - } - }) - }, - - resetInteractiveSegState: () => { - get().updateInteractiveSegState(defaultValues.interactiveSegState) - }, - - handleInteractiveSegAccept: () => { - set((state) => { - if (state.interactiveSegState.tmpInteractiveSegMask) { - state.editorState.extraMasks.push( - castDraft(state.interactiveSegState.tmpInteractiveSegMask) - ) - } - state.interactiveSegState = castDraft({ - ...defaultValues.interactiveSegState, - }) - }) - }, - - setIsInpainting: (newValue: boolean) => - set((state) => { - state.isInpainting = newValue - }), - - setFile: async (file: File) => { - if (get().settings.enableAutoExtractPrompt) { - try { - const res = await getGenInfo(file) - if (res.prompt) { - set((state) => { - state.settings.prompt = res.prompt - }) - } - if (res.negative_prompt) { - set((state) => { - state.settings.negativePrompt = res.negative_prompt - }) - } - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - } - } - set((state) => { - state.file = file - state.interactiveSegState = castDraft( - defaultValues.interactiveSegState - ) - state.editorState = castDraft(defaultValues.editorState) - state.cropperState = defaultValues.cropperState - }) - }, - - setCustomFile: (file: File) => - set((state) => { - state.customMask = file - }), - - setBaseBrushSize: (newValue: number) => - set((state) => { - state.editorState.baseBrushSize = newValue - }), - - decreaseBaseBrushSize: () => { - const baseBrushSize = get().editorState.baseBrushSize - let newBrushSize = baseBrushSize - if (baseBrushSize > 10) { - newBrushSize = baseBrushSize - 10 - } - if (baseBrushSize <= 10 && baseBrushSize > 0) { - newBrushSize = baseBrushSize - 3 - } - get().setBaseBrushSize(newBrushSize) - }, - - increaseBaseBrushSize: () => { - const baseBrushSize = get().editorState.baseBrushSize - const newBrushSize = Math.min(baseBrushSize + 10, MAX_BRUSH_SIZE) - get().setBaseBrushSize(newBrushSize) - }, - - setImageSize: (width: number, height: number) => { - // 根据图片尺寸调整 brushSize 的 scale - set((state) => { - state.imageWidth = width - state.imageHeight = height - state.editorState.brushSizeScale = - Math.max(Math.min(width, height), 512) / 512 - }) - get().resetExtender(width, height) - }, - - setCropperX: (newValue: number) => - set((state) => { - state.cropperState.x = newValue - }), - - setCropperY: (newValue: number) => - set((state) => { - state.cropperState.y = newValue - }), - - setCropperWidth: (newValue: number) => - set((state) => { - state.cropperState.width = newValue - }), - - setCropperHeight: (newValue: number) => - set((state) => { - state.cropperState.height = newValue - }), - - setExtenderX: (newValue: number) => - set((state) => { - state.extenderState.x = newValue - }), - - setExtenderY: (newValue: number) => - set((state) => { - state.extenderState.y = newValue - }), - - setExtenderWidth: (newValue: number) => - set((state) => { - state.extenderState.width = newValue - }), - - setExtenderHeight: (newValue: number) => - set((state) => { - state.extenderState.height = newValue - }), - - setIsCropperExtenderResizing: (newValue: boolean) => - set((state) => { - state.isCropperExtenderResizing = newValue - }), - - updateExtenderDirection: (newValue: ExtenderDirection) => { - console.log( - `updateExtenderDirection: ${JSON.stringify(get().extenderState)}` - ) - set((state) => { - state.settings.extenderDirection = newValue - state.extenderState.x = 0 - state.extenderState.y = 0 - state.extenderState.width = state.imageWidth - state.extenderState.height = state.imageHeight - }) - get().updateExtenderByBuiltIn(newValue, 1.5) - }, - - updateExtenderByBuiltIn: ( - direction: ExtenderDirection, - scale: number - ) => { - const newExtenderState = { ...defaultValues.extenderState } - let { x, y, width, height } = newExtenderState - const { imageWidth, imageHeight } = get() - width = imageWidth - height = imageHeight - - switch (direction) { - case ExtenderDirection.x: - x = -Math.ceil((imageWidth * (scale - 1)) / 2) - width = Math.ceil(imageWidth * scale) - break - case ExtenderDirection.y: - y = -Math.ceil((imageHeight * (scale - 1)) / 2) - height = Math.ceil(imageHeight * scale) - break - case ExtenderDirection.xy: - x = -Math.ceil((imageWidth * (scale - 1)) / 2) - y = -Math.ceil((imageHeight * (scale - 1)) / 2) - width = Math.ceil(imageWidth * scale) - height = Math.ceil(imageHeight * scale) - break - default: - break - } - - set((state) => { - state.extenderState.x = x - state.extenderState.y = y - state.extenderState.width = width - state.extenderState.height = height - }) - }, - - resetExtender: (width: number, height: number) => { - set((state) => { - state.extenderState.x = 0 - state.extenderState.y = 0 - state.extenderState.width = width - state.extenderState.height = height - }) - }, - - setSeed: (newValue: number) => - set((state) => { - state.settings.seed = newValue - }), - - adjustMask: async (operate: AdjustMaskOperate) => { - const { imageWidth, imageHeight } = get() - const { curLineGroup, extraMasks } = get().editorState - const { adjustMaskKernelSize } = get().settings - if (curLineGroup.length === 0 && extraMasks.length === 0) { - return - } - - set((state) => { - state.isAdjustingMask = true - }) - - const maskCanvas = generateMask( - imageWidth, - imageHeight, - [curLineGroup], - extraMasks, - BRUSH_COLOR - ) - const maskBlob = dataURItoBlob(maskCanvas.toDataURL()) - const newMaskBlob = await postAdjustMask( - maskBlob, - operate, - adjustMaskKernelSize - ) - const newMask = await blobToImage(newMaskBlob) - - // TODO: currently ignore stroke undo/redo - set((state) => { - state.editorState.extraMasks = [castDraft(newMask)] - state.editorState.curLineGroup = [] - }) - - set((state) => { - state.isAdjustingMask = false - }) - }, - clearMask: () => { - set((state) => { - state.editorState.extraMasks = [] - state.editorState.curLineGroup = [] - }) - }, - })), - { - name: "ZUSTAND_STATE", // name of the item in the storage (must be unique) - version: 1, - partialize: (state) => - Object.fromEntries( - Object.entries(state).filter(([key]) => - ["fileManagerState", "settings"].includes(key) - ) - ), - } - ), - shallow -) + persist( + immer((set, get) => ({ + ...defaultValues, + + showPrevMask: async () => { + if (get().settings.showExtender) { + return; + } + const { + lastLineGroup, + curLineGroup, + prevExtraMasks, + extraMasks, + } = get().editorState; + if (curLineGroup.length !== 0 || extraMasks.length !== 0) { + return; + } + const { imageWidth, imageHeight } = get(); + + const maskCanvas = generateMask( + imageWidth, + imageHeight, + [lastLineGroup], + prevExtraMasks, + BRUSH_COLOR, + ); + try { + const maskImage = await canvasToImage(maskCanvas); + set((state) => { + state.editorState.temporaryMasks.push( + castDraft(maskImage), + ); + }); + } catch (e) { + console.error(e); + return; + } + }, + hidePrevMask: () => { + set((state) => { + state.editorState.temporaryMasks = []; + }); + }, + + getCurrentTargetFile: async (): Promise => { + const file = get().file!; // 一定是在 file 加载了以后才可能调用这个函数 + const renders = get().editorState.renders; + + let targetFile = file; + if (renders.length > 0) { + const lastRender = renders[renders.length - 1]; + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type, + ); + } + return targetFile; + }, + + submitMaskImage: async () => { + const { file, imageWidth, imageHeight, settings } = get(); + if (file === null) { + return; + } + + const { + lastLineGroup, + curLineGroup, + lineGroups, + renders, + prevExtraMasks, + extraMasks, + } = get().editorState; + + const useLastLineGroup = + curLineGroup.length === 0 && + extraMasks.length === 0 && + !settings.showExtender; + + let maskImages: HTMLImageElement[] = []; + let maskLineGroup: LineGroup = []; + if (useLastLineGroup === true) { + maskLineGroup = lastLineGroup; + maskImages = prevExtraMasks; + } else { + maskLineGroup = curLineGroup; + maskImages = extraMasks; + } + + if ( + maskLineGroup.length === 0 && + maskImages === null && + !settings.showExtender + ) { + toast({ + variant: "destructive", + description: "Please draw mask on picture", + }); + return; + } + + const newLineGroups = [...lineGroups, maskLineGroup]; + + set((state) => { + state.isInpainting = true; + }); + + let targetFile = file; + if (useLastLineGroup === true) { + if (renders.length > 1) { + const lastRender = renders[renders.length - 2]; + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type, + ); + } + } else if (renders.length > 0) { + const lastRender = renders[renders.length - 1]; + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type, + ); + } + + const maskCanvas = generateMask( + imageWidth, + imageHeight, + [maskLineGroup], + maskImages, + BRUSH_COLOR, + ); + if (useLastLineGroup) { + const temporaryMask = await canvasToImage(maskCanvas); + set((state) => { + state.editorState.temporaryMasks = castDraft([ + temporaryMask, + ]); + }); + } + + try { + const res = await submitMask( + targetFile, + dataURItoBlob(maskCanvas.toDataURL()), + ); + + const { blob, seed } = res; + if (seed) { + get().setSeed(parseInt(seed, 10)); + } + const newRender = new Image(); + await loadImage(newRender, blob); + const newRenders = [...renders, newRender]; + get().setImageSize(newRender.width, newRender.height); + get().updateEditorState({ + renders: newRenders, + lineGroups: newLineGroups, + lastLineGroup: maskLineGroup, + curLineGroup: [], + extraMasks: [], + prevExtraMasks: maskImages, + }); + } catch (e: any) { + toast({ + // variant: "default", + description: e.message ? e.message : e.toString(), + }); + } + + get().resetRedoState(); + set((state) => { + state.isInpainting = false; + state.editorState.temporaryMasks = []; + }); + }, + + runInpainting: async () => { + const { + isInpainting, + file, + paintByExampleFile, + imageWidth, + imageHeight, + settings, + cropperState, + extenderState, + } = get(); + if (isInpainting || file === null) { + return; + } + if ( + get().settings.model.support_outpainting && + settings.showExtender && + extenderState.height === imageHeight && + extenderState.width === imageWidth + ) { + return; + } + + const { + lastLineGroup, + curLineGroup, + lineGroups, + renders, + prevExtraMasks, + extraMasks, + } = get().editorState; + + const useLastLineGroup = + curLineGroup.length === 0 && + extraMasks.length === 0 && + !settings.showExtender; + + // useLastLineGroup 的影响 + // 1. 使用上一次的 mask + // 2. 结果替换当前 render + let maskImages: HTMLImageElement[] = []; + let maskLineGroup: LineGroup = []; + if (useLastLineGroup === true) { + maskLineGroup = lastLineGroup; + maskImages = prevExtraMasks; + } else { + maskLineGroup = curLineGroup; + maskImages = extraMasks; + } + + if ( + maskLineGroup.length === 0 && + maskImages === null && + !settings.showExtender + ) { + toast({ + variant: "destructive", + description: "Please draw mask on picture", + }); + return; + } + + const newLineGroups = [...lineGroups, maskLineGroup]; + + set((state) => { + state.isInpainting = true; + }); + + let targetFile = file; + if (useLastLineGroup === true) { + // renders.length == 1 还是用原来的 + if (renders.length > 1) { + const lastRender = renders[renders.length - 2]; + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type, + ); + } + } else if (renders.length > 0) { + const lastRender = renders[renders.length - 1]; + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type, + ); + } + + const maskCanvas = generateMask( + imageWidth, + imageHeight, + [maskLineGroup], + maskImages, + BRUSH_COLOR, + ); + if (useLastLineGroup) { + const temporaryMask = await canvasToImage(maskCanvas); + set((state) => { + state.editorState.temporaryMasks = castDraft([ + temporaryMask, + ]); + }); + } + + try { + const res = await inpaint( + targetFile, + settings, + cropperState, + extenderState, + dataURItoBlob(maskCanvas.toDataURL()), + paintByExampleFile, + ); + + const { blob, seed } = res; + if (seed) { + get().setSeed(parseInt(seed, 10)); + } + const newRender = new Image(); + await loadImage(newRender, blob); + const newRenders = [...renders, newRender]; + get().setImageSize(newRender.width, newRender.height); + get().updateEditorState({ + renders: newRenders, + lineGroups: newLineGroups, + lastLineGroup: maskLineGroup, + curLineGroup: [], + extraMasks: [], + prevExtraMasks: maskImages, + }); + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }); + } + + get().resetRedoState(); + set((state) => { + state.isInpainting = false; + state.editorState.temporaryMasks = []; + }); + }, + + runRenderablePlugin: async ( + genMask: boolean, + pluginName: string, + params: PluginParams = { upscale: 1 }, + ) => { + const { renders, lineGroups } = get().editorState; + set((state) => { + state.isPluginRunning = true; + }); + + try { + const start = new Date(); + const targetFile = await get().getCurrentTargetFile(); + const res = await runPlugin( + genMask, + pluginName, + targetFile, + params.upscale, + ); + const { blob } = res; + + if (!genMask) { + const newRender = new Image(); + await loadImage(newRender, blob); + get().setImageSize(newRender.width, newRender.height); + const newRenders = [...renders, newRender]; + const newLineGroups = [...lineGroups, []]; + get().updateEditorState({ + renders: newRenders, + lineGroups: newLineGroups, + }); + } else { + const newMask = new Image(); + await loadImage(newMask, blob); + set((state) => { + state.editorState.extraMasks.push( + castDraft(newMask), + ); + }); + } + const end = new Date(); + const time = end.getTime() - start.getTime(); + toast({ + description: `Run ${pluginName} successfully in ${ + time / 1000 + }s`, + }); + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }); + } + set((state) => { + state.isPluginRunning = false; + }); + }, + + // Edirot State // + updateEditorState: (newState: Partial) => { + set((state) => { + state.editorState = castDraft({ + ...state.editorState, + ...newState, + }); + }); + }, + + cleanCurLineGroup: () => { + get().updateEditorState({ curLineGroup: [] }); + }, + + handleCanvasMouseDown: (point: Point) => { + let lineGroup: LineGroup = []; + const state = get(); + if (state.runMannually()) { + lineGroup = [...state.editorState.curLineGroup]; + } + lineGroup.push({ size: state.getBrushSize(), pts: [point] }); + set((state) => { + state.editorState.curLineGroup = lineGroup; + }); + }, + + handleCanvasMouseMove: (point: Point) => { + set((state) => { + const curLineGroup = state.editorState.curLineGroup; + if (curLineGroup.length) { + curLineGroup[curLineGroup.length - 1].pts.push(point); + } + }); + }, + + runMannually: (): boolean => { + const state = get(); + return ( + state.settings.enableManualInpainting || + state.settings.model.model_type !== MODEL_TYPE_INPAINT + ); + }, + + getIsProcessing: (): boolean => { + return ( + get().isInpainting || + get().isPluginRunning || + get().isAdjustingMask + ); + }, + + isSD: (): boolean => { + return get().settings.model.model_type !== MODEL_TYPE_INPAINT; + }, + + // undo/redo + + undoDisabled: (): boolean => { + const editorState = get().editorState; + if (editorState.renders.length > 0) { + return false; + } + if (get().runMannually()) { + if (editorState.curLineGroup.length === 0) { + return true; + } + } else if (editorState.renders.length === 0) { + return true; + } + return false; + }, + + undo: () => { + if ( + get().runMannually() && + get().editorState.curLineGroup.length !== 0 + ) { + // undoStroke + set((state) => { + const editorState = state.editorState; + if (editorState.curLineGroup.length === 0) { + return; + } + editorState.lastLineGroup = []; + const lastLine = editorState.curLineGroup.pop()!; + editorState.redoCurLines.push(lastLine); + }); + } else { + set((state) => { + const editorState = state.editorState; + if ( + editorState.renders.length === 0 || + editorState.lineGroups.length === 0 + ) { + return; + } + const lastLineGroup = editorState.lineGroups.pop()!; + editorState.redoLineGroups.push(lastLineGroup); + editorState.redoCurLines = []; + editorState.curLineGroup = []; + + const lastRender = editorState.renders.pop()!; + editorState.redoRenders.push(lastRender); + }); + } + }, + + redoDisabled: (): boolean => { + const editorState = get().editorState; + if (editorState.redoRenders.length > 0) { + return false; + } + if (get().runMannually()) { + if (editorState.redoCurLines.length === 0) { + return true; + } + } else if (editorState.redoRenders.length === 0) { + return true; + } + return false; + }, + + redo: () => { + if ( + get().runMannually() && + get().editorState.redoCurLines.length !== 0 + ) { + set((state) => { + const editorState = state.editorState; + if (editorState.redoCurLines.length === 0) { + return; + } + const line = editorState.redoCurLines.pop()!; + editorState.curLineGroup.push(line); + }); + } else { + set((state) => { + const editorState = state.editorState; + if ( + editorState.redoRenders.length === 0 || + editorState.redoLineGroups.length === 0 + ) { + return; + } + const lastLineGroup = editorState.redoLineGroups.pop()!; + editorState.lineGroups.push(lastLineGroup); + editorState.curLineGroup = []; + + const lastRender = editorState.redoRenders.pop()!; + editorState.renders.push(lastRender); + }); + } + }, + + resetRedoState: () => { + set((state) => { + state.editorState.redoCurLines = []; + state.editorState.redoLineGroups = []; + state.editorState.redoRenders = []; + }); + }, + + //****// + + updateAppState: (newState: Partial) => { + set(() => newState); + }, + + getBrushSize: (): number => { + return ( + get().editorState.baseBrushSize * + get().editorState.brushSizeScale + ); + }, + + showPromptInput: (): boolean => { + const model = get().settings.model; + return ( + model.model_type !== MODEL_TYPE_INPAINT && + model.name !== PAINT_BY_EXAMPLE + ); + }, + + setServerConfig: (newValue: ServerConfig) => { + set((state) => { + state.serverConfig = newValue; + state.settings.enableControlnet = newValue.enableControlnet; + state.settings.controlnetMethod = newValue.controlnetMethod; + }); + }, + + updateSettings: (newSettings: Partial) => { + set((state) => { + state.settings = { + ...state.settings, + ...newSettings, + }; + }); + }, + + setModel: (newModel: ModelInfo) => { + set((state) => { + state.settings.model = newModel; + + if ( + newModel.support_controlnet && + !newModel.controlnets.includes( + state.settings.controlnetMethod, + ) + ) { + state.settings.controlnetMethod = + newModel.controlnets[0]; + } + }); + }, + + updateFileManagerState: (newState: Partial) => { + set((state) => { + state.fileManagerState = { + ...state.fileManagerState, + ...newState, + }; + }); + }, + + updateInteractiveSegState: ( + newState: Partial, + ) => { + set((state) => { + return { + ...state, + interactiveSegState: { + ...state.interactiveSegState, + ...newState, + }, + }; + }); + }, + + resetInteractiveSegState: () => { + get().updateInteractiveSegState( + defaultValues.interactiveSegState, + ); + }, + + handleInteractiveSegAccept: () => { + set((state) => { + if (state.interactiveSegState.tmpInteractiveSegMask) { + state.editorState.extraMasks.push( + castDraft( + state.interactiveSegState.tmpInteractiveSegMask, + ), + ); + } + state.interactiveSegState = castDraft({ + ...defaultValues.interactiveSegState, + }); + }); + }, + + setIsInpainting: (newValue: boolean) => + set((state) => { + state.isInpainting = newValue; + }), + + setFile: async (file: File) => { + if (get().settings.enableAutoExtractPrompt) { + try { + const res = await getGenInfo(file); + if (res.prompt) { + set((state) => { + state.settings.prompt = res.prompt; + }); + } + if (res.negative_prompt) { + set((state) => { + state.settings.negativePrompt = + res.negative_prompt; + }); + } + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }); + } + } + set((state) => { + state.file = file; + state.interactiveSegState = castDraft( + defaultValues.interactiveSegState, + ); + state.editorState = castDraft(defaultValues.editorState); + state.cropperState = defaultValues.cropperState; + }); + }, + + setCustomFile: (file: File) => + set((state) => { + state.customMask = file; + }), + + setBaseBrushSize: (newValue: number) => + set((state) => { + state.editorState.baseBrushSize = newValue; + }), + + decreaseBaseBrushSize: () => { + const baseBrushSize = get().editorState.baseBrushSize; + let newBrushSize = baseBrushSize; + if (baseBrushSize > 10) { + newBrushSize = baseBrushSize - 10; + } + if (baseBrushSize <= 10 && baseBrushSize > 0) { + newBrushSize = baseBrushSize - 3; + } + get().setBaseBrushSize(newBrushSize); + }, + + increaseBaseBrushSize: () => { + const baseBrushSize = get().editorState.baseBrushSize; + const newBrushSize = Math.min( + baseBrushSize + 10, + MAX_BRUSH_SIZE, + ); + get().setBaseBrushSize(newBrushSize); + }, + + setImageSize: (width: number, height: number) => { + // 根据图片尺寸调整 brushSize 的 scale + set((state) => { + state.imageWidth = width; + state.imageHeight = height; + state.editorState.brushSizeScale = + Math.max(Math.min(width, height), 512) / 512; + }); + get().resetExtender(width, height); + }, + + setCropperX: (newValue: number) => + set((state) => { + state.cropperState.x = newValue; + }), + + setCropperY: (newValue: number) => + set((state) => { + state.cropperState.y = newValue; + }), + + setCropperWidth: (newValue: number) => + set((state) => { + state.cropperState.width = newValue; + }), + + setCropperHeight: (newValue: number) => + set((state) => { + state.cropperState.height = newValue; + }), + + setExtenderX: (newValue: number) => + set((state) => { + state.extenderState.x = newValue; + }), + + setExtenderY: (newValue: number) => + set((state) => { + state.extenderState.y = newValue; + }), + + setExtenderWidth: (newValue: number) => + set((state) => { + state.extenderState.width = newValue; + }), + + setExtenderHeight: (newValue: number) => + set((state) => { + state.extenderState.height = newValue; + }), + + setIsCropperExtenderResizing: (newValue: boolean) => + set((state) => { + state.isCropperExtenderResizing = newValue; + }), + + updateExtenderDirection: (newValue: ExtenderDirection) => { + console.log( + `updateExtenderDirection: ${JSON.stringify( + get().extenderState, + )}`, + ); + set((state) => { + state.settings.extenderDirection = newValue; + state.extenderState.x = 0; + state.extenderState.y = 0; + state.extenderState.width = state.imageWidth; + state.extenderState.height = state.imageHeight; + }); + get().updateExtenderByBuiltIn(newValue, 1.5); + }, + + updateExtenderByBuiltIn: ( + direction: ExtenderDirection, + scale: number, + ) => { + const newExtenderState = { ...defaultValues.extenderState }; + let { x, y, width, height } = newExtenderState; + const { imageWidth, imageHeight } = get(); + width = imageWidth; + height = imageHeight; + + switch (direction) { + case ExtenderDirection.x: + x = -Math.ceil((imageWidth * (scale - 1)) / 2); + width = Math.ceil(imageWidth * scale); + break; + case ExtenderDirection.y: + y = -Math.ceil((imageHeight * (scale - 1)) / 2); + height = Math.ceil(imageHeight * scale); + break; + case ExtenderDirection.xy: + x = -Math.ceil((imageWidth * (scale - 1)) / 2); + y = -Math.ceil((imageHeight * (scale - 1)) / 2); + width = Math.ceil(imageWidth * scale); + height = Math.ceil(imageHeight * scale); + break; + default: + break; + } + + set((state) => { + state.extenderState.x = x; + state.extenderState.y = y; + state.extenderState.width = width; + state.extenderState.height = height; + }); + }, + + resetExtender: (width: number, height: number) => { + set((state) => { + state.extenderState.x = 0; + state.extenderState.y = 0; + state.extenderState.width = width; + state.extenderState.height = height; + }); + }, + + setSeed: (newValue: number) => + set((state) => { + state.settings.seed = newValue; + }), + + adjustMask: async (operate: AdjustMaskOperate) => { + const { imageWidth, imageHeight } = get(); + const { curLineGroup, extraMasks } = get().editorState; + const { adjustMaskKernelSize } = get().settings; + if (curLineGroup.length === 0 && extraMasks.length === 0) { + return; + } + + set((state) => { + state.isAdjustingMask = true; + }); + + const maskCanvas = generateMask( + imageWidth, + imageHeight, + [curLineGroup], + extraMasks, + BRUSH_COLOR, + ); + const maskBlob = dataURItoBlob(maskCanvas.toDataURL()); + const newMaskBlob = await postAdjustMask( + maskBlob, + operate, + adjustMaskKernelSize, + ); + const newMask = await blobToImage(newMaskBlob); + + // TODO: currently ignore stroke undo/redo + set((state) => { + state.editorState.extraMasks = [castDraft(newMask)]; + state.editorState.curLineGroup = []; + }); + + set((state) => { + state.isAdjustingMask = false; + }); + }, + clearMask: () => { + set((state) => { + state.editorState.extraMasks = []; + state.editorState.curLineGroup = []; + }); + }, + })), + { + name: "ZUSTAND_STATE", // name of the item in the storage (must be unique) + version: 1, + partialize: (state) => + Object.fromEntries( + Object.entries(state).filter(([key]) => + ["fileManagerState", "settings"].includes(key), + ), + ), + }, + ), + shallow, +);