diff --git a/custom-demo/back-end/.gitignore b/custom-demo/back-end/.gitignore new file mode 100644 index 0000000..5f19111 --- /dev/null +++ b/custom-demo/back-end/.gitignore @@ -0,0 +1,14 @@ +.DS_Store +**/__pycache__ +examples/ +.idea/ +.vscode/ +build +!iopaint/app/build +dist/ +IOPaint.egg-info/ +venv/ +tmp/ +# iopaint/web_app/ +iopaint/.venv +iopaint/output_nohup.log \ No newline at end of file diff --git a/custom-demo/back-end/README.md b/custom-demo/back-end/README.md new file mode 100644 index 0000000..f59ecca --- /dev/null +++ b/custom-demo/back-end/README.md @@ -0,0 +1,7 @@ +## Run BE +```shell +cd /home/lama-cleaner/iopaint +source .venv/bin/activate +root@aitool:/home/lama-cleaner/iopaint# nohup python3 ../main.py start --enable-remove-bg > output_nohup.log 2>&1 & +``` + diff --git a/custom-demo/back-end/__init__.py b/custom-demo/back-end/__init__.py new file mode 100644 index 0000000..3accbc5 --- /dev/null +++ b/custom-demo/back-end/__init__.py @@ -0,0 +1,24 @@ +# __init__.py +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068 +os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1" +os.environ["LRU_CACHE_CAPACITY"] = "1" +# prevent CPU memory leak when run model on GPU +# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431 +# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633 +os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1" + + +import warnings + +warnings.simplefilter("ignore", UserWarning) + + +def entry_point(): + # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers + # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 + from iopaint.cli import typer_app + + typer_app() diff --git a/custom-demo/back-end/__main__.py b/custom-demo/back-end/__main__.py new file mode 100644 index 0000000..5a9d51e --- /dev/null +++ b/custom-demo/back-end/__main__.py @@ -0,0 +1,5 @@ +# __main__.py +from iopaint import entry_point + +if __name__ == "__main__": + entry_point() diff --git a/custom-demo/back-end/api.py b/custom-demo/back-end/api.py new file mode 100644 index 0000000..3c9df6a --- /dev/null +++ b/custom-demo/back-end/api.py @@ -0,0 +1,397 @@ +import asyncio +import os +import threading +import time +import traceback +from pathlib import Path +from typing import Optional, Dict, List + +import cv2 +import numpy as np +import socketio +import torch + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + + +import uvicorn +from PIL import Image +from fastapi import APIRouter, FastAPI, Request, UploadFile +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, FileResponse, Response +from fastapi.staticfiles import StaticFiles +from loguru import logger +from socketio import AsyncServer + +from iopaint.file_manager import FileManager +from iopaint.helper import ( + load_img, + decode_base64_to_image, + pil_to_bytes, + numpy_to_bytes, + concat_alpha_channel, + gen_frontend_mask, + adjust_mask, +) +from iopaint.model.utils import torch_gc +from iopaint.model_manager import ModelManager +from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg +from iopaint.plugins.base_plugin import BasePlugin +from iopaint.plugins.remove_bg import RemoveBG +from iopaint.schema import ( + GenInfoResponse, + ApiConfig, + ServerConfigResponse, + SwitchModelRequest, + InpaintRequest, + RunPluginRequest, + SDSampler, + PluginInfo, + AdjustMaskRequest, + RemoveBGModel, + SwitchPluginModelRequest, + ModelInfo, + InteractiveSegModel, + RealESRGANModel, +) + +CURRENT_DIR = Path(__file__).parent.absolute().resolve() +WEB_APP_DIR = CURRENT_DIR / "web_app" + + +def api_middleware(app: FastAPI): + rich_available = False + try: + if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + + console = Console() + rich_available = True + except Exception: + pass + + def handle_exception(request: Request, e: Exception): + err = { + "error": type(e).__name__, + "detail": vars(e).get("detail", ""), + "body": vars(e).get("body", ""), + "errors": str(e), + } + if not isinstance( + e, HTTPException + ): # do not print backtrace on known httpexceptions + message = f"API error: {request.method}: {request.url} {err}" + if rich_available: + print(message) + console.print_exception( + show_locals=True, + max_frames=2, + extra_lines=1, + suppress=[anyio, starlette], + word_wrap=False, + width=min([console.width, 200]), + ) + else: + traceback.print_exc() + return JSONResponse( + status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err) + ) + + @app.middleware("http") + async def exception_handling(request: Request, call_next): + try: + return await call_next(request) + except Exception as e: + return handle_exception(request, e) + + @app.exception_handler(Exception) + async def fastapi_exception_handler(request: Request, e: Exception): + return handle_exception(request, e) + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, e: HTTPException): + return handle_exception(request, e) + + cors_options = { + "allow_methods": ["*"], + "allow_headers": ["*"], + "allow_origins": ["*"], + "allow_credentials": True, + "expose_headers": ["X-Seed"] + } + app.add_middleware(CORSMiddleware, **cors_options) + + +global_sio: AsyncServer = None + + +def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}): + # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict + # logger.info(f"diffusion callback: step={step}, timestep={timestep}") + + # We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI, + # but for now let's just start a separate event loop. It shouldn't make a difference for single person use + asyncio.run(global_sio.emit("diffusion_progress", {"step": step})) + return {} + + +class Api: + def __init__(self, app: FastAPI, config: ApiConfig): + self.app = app + self.config = config + self.router = APIRouter() + self.queue_lock = threading.Lock() + api_middleware(self.app) + + self.file_manager = self._build_file_manager() + self.plugins = self._build_plugins() + self.model_manager = self._build_model_manager() + + # fmt: off + # self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse) + self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse) + self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo) + self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) + self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) + self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"]) + self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"]) + self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"]) + self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"]) + self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) + self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"]) + self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"]) + self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") + # fmt: on + + global global_sio + self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app) + self.app.mount("/ws", self.combined_asgi_app) + global_sio = self.sio + + def add_api_route(self, path: str, endpoint, **kwargs): + return self.app.add_api_route(path, endpoint, **kwargs) + + def api_save_image(self, file: UploadFile): + filename = file.filename + origin_image_bytes = file.file.read() + with open(self.config.output_dir / filename, "wb") as fw: + fw.write(origin_image_bytes) + + def api_current_model(self) -> ModelInfo: + return self.model_manager.current_model + + def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo: + if req.name == self.model_manager.name: + return self.model_manager.current_model + self.model_manager.switch(req.name) + return self.model_manager.current_model + + def api_switch_plugin_model(self, req: SwitchPluginModelRequest): + if req.plugin_name in self.plugins: + self.plugins[req.plugin_name].switch_model(req.model_name) + if req.plugin_name == RemoveBG.name: + self.config.remove_bg_model = req.model_name + if req.plugin_name == RealESRGANUpscaler.name: + self.config.realesrgan_model = req.model_name + if req.plugin_name == InteractiveSeg.name: + self.config.interactive_seg_model = req.model_name + torch_gc() + + def api_server_config(self) -> ServerConfigResponse: + plugins = [] + for it in self.plugins.values(): + plugins.append( + PluginInfo( + name=it.name, + support_gen_image=it.support_gen_image, + support_gen_mask=it.support_gen_mask, + ) + ) + + return ServerConfigResponse( + plugins=plugins, + modelInfos=self.model_manager.scan_models(), + removeBGModel=self.config.remove_bg_model, + removeBGModels=RemoveBGModel.values(), + realesrganModel=self.config.realesrgan_model, + realesrganModels=RealESRGANModel.values(), + interactiveSegModel=self.config.interactive_seg_model, + interactiveSegModels=InteractiveSegModel.values(), + enableFileManager=self.file_manager is not None, + enableAutoSaving=self.config.output_dir is not None, + enableControlnet=self.model_manager.enable_controlnet, + controlnetMethod=self.model_manager.controlnet_method, + disableModelSwitch=False, + isDesktop=False, + samplers=self.api_samplers(), + ) + + def api_input_image(self) -> FileResponse: + if self.config.input and self.config.input.is_file(): + return FileResponse(self.config.input) + raise HTTPException(status_code=404, detail="Input image not found") + + def api_geninfo(self, file: UploadFile) -> GenInfoResponse: + _, _, info = load_img(file.file.read(), return_info=True) + parts = info.get("parameters", "").split("Negative prompt: ") + prompt = parts[0].strip() + negative_prompt = "" + if len(parts) > 1: + negative_prompt = parts[1].split("\n")[0].strip() + return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt) + + def api_inpaint(self, req: InpaintRequest): + image, alpha_channel, infos = decode_base64_to_image(req.image) + mask, _, _ = decode_base64_to_image(req.mask, gray=True) + + mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] + if image.shape[:2] != mask.shape[:2]: + raise HTTPException( + 400, + detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.", + ) + + if req.paint_by_example_example_image: + paint_by_example_image, _, _ = decode_base64_to_image( + req.paint_by_example_example_image + ) + + start = time.time() + rgb_np_img = self.model_manager(image, mask, req) + logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms") + torch_gc() + + rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) + rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel) + + ext = "png" + res_img_bytes = pil_to_bytes( + Image.fromarray(rgb_res), + ext=ext, + quality=self.config.quality, + infos=infos, + ) + + asyncio.run(self.sio.emit("diffusion_finish")) + + return Response( + content=res_img_bytes, + media_type=f"image/{ext}", + headers={"X-Seed": str(req.sd_seed)}, + ) + + def api_run_plugin_gen_image(self, req: RunPluginRequest): + ext = "png" + if req.name not in self.plugins: + raise HTTPException(status_code=422, detail="Plugin not found") + if not self.plugins[req.name].support_gen_image: + raise HTTPException( + status_code=422, detail="Plugin does not support output image" + ) + rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) + bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req) + torch_gc() + + if bgr_or_rgba_np_img.shape[2] == 4: + rgba_np_img = bgr_or_rgba_np_img + else: + rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB) + rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel) + + return Response( + content=pil_to_bytes( + Image.fromarray(rgba_np_img), + ext=ext, + quality=self.config.quality, + infos=infos, + ), + media_type=f"image/{ext}", + ) + + def api_run_plugin_gen_mask(self, req: RunPluginRequest): + if req.name not in self.plugins: + raise HTTPException(status_code=422, detail="Plugin not found") + if not self.plugins[req.name].support_gen_mask: + raise HTTPException( + status_code=422, detail="Plugin does not support output image" + ) + rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) + bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req) + torch_gc() + res_mask = gen_frontend_mask(bgr_or_gray_mask) + return Response( + content=numpy_to_bytes(res_mask, "png"), + media_type="image/png", + ) + + def api_samplers(self) -> List[str]: + return [member.value for member in SDSampler.__members__.values()] + + def api_adjust_mask(self, req: AdjustMaskRequest): + mask, _, _ = decode_base64_to_image(req.mask, gray=True) + mask = adjust_mask(mask, req.kernel_size, req.operate) + return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png") + + def launch(self): + self.app.include_router(self.router) + uvicorn.run( + self.combined_asgi_app, + host=self.config.host, + port=self.config.port, + timeout_keep_alive=999999999, + ) + + def _build_file_manager(self) -> Optional[FileManager]: + if self.config.input and self.config.input.is_dir(): + logger.info( + f"Input is directory, initialize file manager {self.config.input}" + ) + + return FileManager( + app=self.app, + input_dir=self.config.input, + output_dir=self.config.output_dir, + ) + return None + + def _build_plugins(self) -> Dict[str, BasePlugin]: + return build_plugins( + self.config.enable_interactive_seg, + self.config.interactive_seg_model, + self.config.interactive_seg_device, + self.config.enable_remove_bg, + self.config.remove_bg_model, + self.config.enable_anime_seg, + self.config.enable_realesrgan, + self.config.realesrgan_device, + self.config.realesrgan_model, + self.config.enable_gfpgan, + self.config.gfpgan_device, + self.config.enable_restoreformer, + self.config.restoreformer_device, + self.config.no_half, + ) + + def _build_model_manager(self): + return ModelManager( + name=self.config.model, + device=torch.device(self.config.device), + no_half=self.config.no_half, + low_mem=self.config.low_mem, + disable_nsfw=self.config.disable_nsfw_checker, + sd_cpu_textencoder=self.config.cpu_textencoder, + local_files_only=self.config.local_files_only, + cpu_offload=self.config.cpu_offload, + callback=diffuser_callback, + ) diff --git a/custom-demo/back-end/batch_processing.py b/custom-demo/back-end/batch_processing.py new file mode 100644 index 0000000..393a720 --- /dev/null +++ b/custom-demo/back-end/batch_processing.py @@ -0,0 +1,127 @@ +import json +from pathlib import Path +from typing import Dict, Optional + +import cv2 +import psutil +from PIL import Image +from loguru import logger +from rich.console import Console +from rich.progress import ( + Progress, + SpinnerColumn, + TimeElapsedColumn, + MofNCompleteColumn, + TextColumn, + BarColumn, + TaskProgressColumn, +) + +from iopaint.helper import pil_to_bytes +from iopaint.model.utils import torch_gc +from iopaint.model_manager import ModelManager +from iopaint.schema import InpaintRequest + + +def glob_images(path: Path) -> Dict[str, Path]: + # png/jpg/jpeg + if path.is_file(): + return {path.stem: path} + elif path.is_dir(): + res = {} + for it in path.glob("*.*"): + if it.suffix.lower() in [".png", ".jpg", ".jpeg"]: + res[it.stem] = it + return res + + +def batch_inpaint( + model: str, + device, + image: Path, + mask: Path, + output: Path, + config: Optional[Path] = None, + concat: bool = False, +): + if image.is_dir() and output.is_file(): + logger.error( + f"invalid --output: when image is a directory, output should be a directory" + ) + exit(-1) + output.mkdir(parents=True, exist_ok=True) + + image_paths = glob_images(image) + mask_paths = glob_images(mask) + if len(image_paths) == 0: + logger.error(f"invalid --image: empty image folder") + exit(-1) + if len(mask_paths) == 0: + logger.error(f"invalid --mask: empty mask folder") + exit(-1) + + if config is None: + inpaint_request = InpaintRequest() + logger.info(f"Using default config: {inpaint_request}") + else: + with open(config, "r", encoding="utf-8") as f: + inpaint_request = InpaintRequest(**json.load(f)) + + model_manager = ModelManager(name=model, device=device) + first_mask = list(mask_paths.values())[0] + + console = Console() + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + console=console, + transient=False, + ) as progress: + task = progress.add_task("Batch processing...", total=len(image_paths)) + for stem, image_p in image_paths.items(): + if stem not in mask_paths and mask.is_dir(): + progress.log(f"mask for {image_p} not found") + progress.update(task, advance=1) + continue + mask_p = mask_paths.get(stem, first_mask) + + infos = Image.open(image_p).info + + img = cv2.imread(str(image_p)) + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) + if mask_img.shape[:2] != img.shape[:2]: + progress.log( + f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}" + ) + mask_img = cv2.resize( + mask_img, + (img.shape[1], img.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + mask_img[mask_img >= 127] = 255 + mask_img[mask_img < 127] = 0 + + # bgr + inpaint_result = model_manager(img, mask_img, inpaint_request) + inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB) + if concat: + mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB) + inpaint_result = cv2.hconcat([img, mask_img, inpaint_result]) + + img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos) + save_p = output / f"{stem}.png" + with open(save_p, "wb") as fw: + fw.write(img_bytes) + + progress.update(task, advance=1) + torch_gc() + # pid = psutil.Process().pid + # memory_info = psutil.Process(pid).memory_info() + # memory_in_mb = memory_info.rss / (1024 * 1024) + # print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB") diff --git a/custom-demo/back-end/benchmark.py b/custom-demo/back-end/benchmark.py new file mode 100644 index 0000000..0205c60 --- /dev/null +++ b/custom-demo/back-end/benchmark.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +import argparse +import os +import time + +import numpy as np +import nvidia_smi +import psutil +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import InpaintRequest, HDStrategy, SDSampler + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + +NUM_THREADS = str(4) + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS +os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS +if os.environ.get("CACHE_DIR"): + os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] + + +def run_model(model, size): + # RGB + image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8) + mask = np.random.randint(0, 255, size).astype(np.uint8) + + config = InpaintRequest( + ldm_steps=2, + hd_strategy=HDStrategy.ORIGINAL, + hd_strategy_crop_margin=128, + hd_strategy_crop_trigger_size=128, + hd_strategy_resize_limit=128, + prompt="a fox is sitting on a bench", + sd_steps=5, + sd_sampler=SDSampler.ddim, + ) + model(image, mask, config) + + +def benchmark(model, times: int, empty_cache: bool): + sizes = [(512, 512)] + + nvidia_smi.nvmlInit() + device_id = 0 + handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id) + + def format(metrics): + return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}" + + process = psutil.Process(os.getpid()) + # 每个 size 给出显存和内存占用的指标 + for size in sizes: + torch.cuda.empty_cache() + time_metrics = [] + cpu_metrics = [] + memory_metrics = [] + gpu_memory_metrics = [] + for _ in range(times): + start = time.time() + run_model(model, size) + torch.cuda.synchronize() + + # cpu_metrics.append(process.cpu_percent()) + time_metrics.append((time.time() - start) * 1000) + memory_metrics.append(process.memory_info().rss / 1024 / 1024) + gpu_memory_metrics.append( + nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024 + ) + + print(f"size: {size}".center(80, "-")) + # print(f"cpu: {format(cpu_metrics)}") + print(f"latency: {format(time_metrics)}ms") + print(f"memory: {format(memory_metrics)} MB") + print(f"gpu memory: {format(gpu_memory_metrics)} MB") + + nvidia_smi.nvmlShutdown() + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--name") + parser.add_argument("--device", default="cuda", type=str) + parser.add_argument("--times", default=10, type=int) + parser.add_argument("--empty-cache", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args_parser() + device = torch.device(args.device) + model = ModelManager( + name=args.name, + device=device, + disable_nsfw=True, + sd_cpu_textencoder=True, + ) + benchmark(model, args.times, args.empty_cache) diff --git a/custom-demo/back-end/cli.py b/custom-demo/back-end/cli.py new file mode 100644 index 0000000..951fbb4 --- /dev/null +++ b/custom-demo/back-end/cli.py @@ -0,0 +1,223 @@ +import webbrowser +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Dict, Optional + +import typer +from fastapi import FastAPI +from loguru import logger +from typer import Option +from typer_config import use_json_config + +from iopaint.const import * +from iopaint.runtime import setup_model_dir, dump_environment_info, check_device +from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel + +typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False) + + +@typer_app.command(help="Install all plugins dependencies") +def install_plugins_packages(): + from iopaint.installer import install_plugins_package + + install_plugins_package() + + +@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace") +def download( + model: str = Option( + ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting" + ), + model_dir: Path = Option( + DEFAULT_MODEL_DIR, + help=MODEL_DIR_HELP, + file_okay=False, + callback=setup_model_dir, + ), +): + from iopaint.download import cli_download_model + + cli_download_model(model) + + +@typer_app.command(name="list", help="List downloaded models") +def list_model( + model_dir: Path = Option( + DEFAULT_MODEL_DIR, + help=MODEL_DIR_HELP, + file_okay=False, + callback=setup_model_dir, + ), +): + from iopaint.download import scan_models + + scanned_models = scan_models() + for it in scanned_models: + print(it.name) + + +@typer_app.command(help="Batch processing images") +def run( + model: str = Option("lama"), + device: Device = Option(Device.cpu), + image: Path = Option(..., help="Image folders or file path"), + mask: Path = Option( + ..., + help="Mask folders or file path. " + "If it is a directory, the mask images in the directory should have the same name as the original image." + "If it is a file, all images will use this mask." + "Mask will automatically resize to the same size as the original image.", + ), + output: Path = Option(..., help="Output directory or file path"), + config: Path = Option( + None, help="Config file path. You can use dump command to create a base config." + ), + concat: bool = Option( + False, help="Concat original image, mask and output images into one image" + ), + model_dir: Path = Option( + DEFAULT_MODEL_DIR, + help=MODEL_DIR_HELP, + file_okay=False, + callback=setup_model_dir, + ), +): + from iopaint.download import cli_download_model, scan_models + + scanned_models = scan_models() + if model not in [it.name for it in scanned_models]: + logger.info(f"{model} not found in {model_dir}, try to downloading") + cli_download_model(model) + + from iopaint.batch_processing import batch_inpaint + + batch_inpaint(model, device, image, mask, output, config, concat) + + +@typer_app.command(help="Start IOPaint server") +@use_json_config() +def start( + host: str = Option("127.0.0.1"), + port: int = Option(8080), + inbrowser: bool = Option(False, help=INBROWSER_HELP), + model: str = Option( + DEFAULT_MODEL, + help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n" + f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.", + ), + model_dir: Path = Option( + DEFAULT_MODEL_DIR, + help=MODEL_DIR_HELP, + dir_okay=True, + file_okay=False, + callback=setup_model_dir, + ), + low_mem: bool = Option(False, help=LOW_MEM_HELP), + no_half: bool = Option(False, help=NO_HALF_HELP), + cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP), + disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP), + cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP), + local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP), + device: Device = Option(Device.cpu), + input: Optional[Path] = Option(None, help=INPUT_HELP), + output_dir: Optional[Path] = Option( + None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False + ), + quality: int = Option(95, help=QUALITY_HELP), + enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP), + interactive_seg_model: InteractiveSegModel = Option( + InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP + ), + interactive_seg_device: Device = Option(Device.cpu), + enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), + remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4), + enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP), + enable_realesrgan: bool = Option(False), + realesrgan_device: Device = Option(Device.cpu), + realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3), + enable_gfpgan: bool = Option(False), + gfpgan_device: Device = Option(Device.cpu), + enable_restoreformer: bool = Option(False), + restoreformer_device: Device = Option(Device.cpu), +): + dump_environment_info() + device = check_device(device) + if input and not input.exists(): + logger.error(f"invalid --input: {input} not exists") + exit(-1) + if input and input.is_dir() and not output_dir: + logger.error(f"invalid --output-dir: must be set when --input is a directory") + exit(-1) + if output_dir: + output_dir = output_dir.expanduser().absolute() + logger.info(f"Image will be saved to {output_dir}") + if not output_dir.exists(): + logger.info(f"Create output directory {output_dir}") + output_dir.mkdir(parents=True) + + model_dir = model_dir.expanduser().absolute() + + if local_files_only: + os.environ["TRANSFORMERS_OFFLINE"] = "1" + os.environ["HF_HUB_OFFLINE"] = "1" + + from iopaint.download import cli_download_model, scan_models + + scanned_models = scan_models() + if model not in [it.name for it in scanned_models]: + logger.info(f"{model} not found in {model_dir}, try to downloading") + cli_download_model(model) + + from iopaint.api import Api + from iopaint.schema import ApiConfig + + @asynccontextmanager + async def lifespan(app: FastAPI): + if inbrowser: + webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True) + yield + + app = FastAPI(lifespan=lifespan) + + api_config = ApiConfig( + host=host, + port=port, + inbrowser=inbrowser, + model=model, + no_half=no_half, + low_mem=low_mem, + cpu_offload=cpu_offload, + disable_nsfw_checker=disable_nsfw_checker, + local_files_only=local_files_only, + cpu_textencoder=cpu_textencoder if device == Device.cuda else False, + device=device, + input=input, + output_dir=output_dir, + quality=quality, + enable_interactive_seg=enable_interactive_seg, + interactive_seg_model=interactive_seg_model, + interactive_seg_device=interactive_seg_device, + enable_remove_bg=enable_remove_bg, + remove_bg_model=remove_bg_model, + enable_anime_seg=enable_anime_seg, + enable_realesrgan=enable_realesrgan, + realesrgan_device=realesrgan_device, + realesrgan_model=realesrgan_model, + enable_gfpgan=enable_gfpgan, + gfpgan_device=gfpgan_device, + enable_restoreformer=enable_restoreformer, + restoreformer_device=restoreformer_device, + ) + print(api_config.model_dump_json(indent=4)) + api = Api(app, api_config) + api.launch() + + +@typer_app.command(help="Start IOPaint web config page") +def start_web_config( + config_file: Path = Option("config.json"), +): + dump_environment_info() + from iopaint.web_config import main + + main(config_file) diff --git a/custom-demo/back-end/const.py b/custom-demo/back-end/const.py new file mode 100644 index 0000000..148cb77 --- /dev/null +++ b/custom-demo/back-end/const.py @@ -0,0 +1,120 @@ +import os +from typing import List + +INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix" +KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint" +POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting" +ANYTEXT_NAME = "Sanster/AnyText" + + +DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" +DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline" +DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline" +DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline" + +MPS_UNSUPPORT_MODELS = [ + "lama", + "ldm", + "zits", + "mat", + "fcf", + "cv2", + "manga", +] + +DEFAULT_MODEL = "lama" +AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"] +DIFFUSION_MODELS = [ + "runwayml/stable-diffusion-inpainting", + "Uminosachi/realisticVisionV51_v51VAE-inpainting", + "redstonehero/dreamshaper-inpainting", + "Sanster/anything-4.0-inpainting", + "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + "Fantasy-Studio/Paint-by-Example", + POWERPAINT_NAME, + ANYTEXT_NAME, +] + +NO_HALF_HELP = """ +Using full precision(fp32) model. +If your diffusion model generate result is always black or green, use this argument. +""" + +CPU_OFFLOAD_HELP = """ +Offloads diffusion model's weight to CPU RAM, significantly reducing vRAM usage. +""" + +LOW_MEM_HELP = "Enable attention slicing and vae tiling to save memory." + +DISABLE_NSFW_HELP = """ +Disable NSFW checker for diffusion model. +""" + +CPU_TEXTENCODER_HELP = """ +Run diffusion models text encoder on CPU to reduce vRAM usage. +""" + +SD_CONTROLNET_CHOICES: List[str] = [ + "lllyasviel/control_v11p_sd15_canny", + # "lllyasviel/control_v11p_sd15_seg", + "lllyasviel/control_v11p_sd15_openpose", + "lllyasviel/control_v11p_sd15_inpaint", + "lllyasviel/control_v11f1p_sd15_depth", +] + +SD2_CONTROLNET_CHOICES = [ + "thibaud/controlnet-sd21-canny-diffusers", + "thibaud/controlnet-sd21-depth-diffusers", + "thibaud/controlnet-sd21-openpose-diffusers", +] + +SDXL_CONTROLNET_CHOICES = [ + "thibaud/controlnet-openpose-sdxl-1.0", + "destitech/controlnet-inpaint-dreamer-sdxl", + "diffusers/controlnet-canny-sdxl-1.0", + "diffusers/controlnet-canny-sdxl-1.0-mid", + "diffusers/controlnet-canny-sdxl-1.0-small", + "diffusers/controlnet-depth-sdxl-1.0", + "diffusers/controlnet-depth-sdxl-1.0-mid", + "diffusers/controlnet-depth-sdxl-1.0-small", +] + +LOCAL_FILES_ONLY_HELP = """ +When loading diffusion models, using local files only, not connect to HuggingFace server. +""" + +DEFAULT_MODEL_DIR = os.path.abspath( + os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")) +) + +MODEL_DIR_HELP = f""" +Model download directory (by setting XDG_CACHE_HOME environment variable), by default model download to {DEFAULT_MODEL_DIR} +""" + +OUTPUT_DIR_HELP = """ +Result images will be saved to output directory automatically. +""" + +INPUT_HELP = """ +If input is image, it will be loaded by default. +If input is directory, you can browse and select image in file manager. +""" + +GUI_HELP = """ +Launch Lama Cleaner as desktop app +""" + +QUALITY_HELP = """ +Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size. +""" + +INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything." +INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed." +REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU" +ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU" +REALESRGAN_HELP = "Enable realesrgan super resolution" +GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan" +RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan" +GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image" + +INBROWSER_HELP = "Automatically launch IOPaint in a new tab on the default browser" diff --git a/custom-demo/back-end/download.py b/custom-demo/back-end/download.py new file mode 100644 index 0000000..2ebd7fc --- /dev/null +++ b/custom-demo/back-end/download.py @@ -0,0 +1,294 @@ +import json +import os +from functools import lru_cache +from typing import List + +from iopaint.schema import ModelType, ModelInfo +from loguru import logger +from pathlib import Path + +from iopaint.const import ( + DEFAULT_MODEL_DIR, + DIFFUSERS_SD_CLASS_NAME, + DIFFUSERS_SD_INPAINT_CLASS_NAME, + DIFFUSERS_SDXL_CLASS_NAME, + DIFFUSERS_SDXL_INPAINT_CLASS_NAME, + ANYTEXT_NAME, +) +from iopaint.model.original_sd_configs import get_config_files + + +def cli_download_model(model: str): + from iopaint.model import models + from iopaint.model.utils import handle_from_pretrained_exceptions + + if model in models and models[model].is_erase_model: + logger.info(f"Downloading {model}...") + models[model].download() + logger.info(f"Done.") + elif model == ANYTEXT_NAME: + logger.info(f"Downloading {model}...") + models[model].download() + logger.info(f"Done.") + else: + logger.info(f"Downloading model from Huggingface: {model}") + from diffusers import DiffusionPipeline + + downloaded_path = handle_from_pretrained_exceptions( + DiffusionPipeline.download, + pretrained_model_name=model, + variant="fp16", + resume_download=True, + ) + logger.info(f"Done. Downloaded to {downloaded_path}") + + +def folder_name_to_show_name(name: str) -> str: + return name.replace("models--", "").replace("--", "/") + + +@lru_cache(maxsize=512) +def get_sd_model_type(model_abs_path: str) -> ModelType: + if "inpaint" in Path(model_abs_path).name.lower(): + model_type = ModelType.DIFFUSERS_SD_INPAINT + else: + # load once to check num_in_channels + from diffusers import StableDiffusionInpaintPipeline + + try: + StableDiffusionInpaintPipeline.from_single_file( + model_abs_path, + load_safety_checker=False, + num_in_channels=9, + config_files=get_config_files(), + ) + model_type = ModelType.DIFFUSERS_SD_INPAINT + except ValueError as e: + if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e): + model_type = ModelType.DIFFUSERS_SD + else: + raise e + return model_type + + +@lru_cache() +def get_sdxl_model_type(model_abs_path: str) -> ModelType: + if "inpaint" in model_abs_path: + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + else: + # load once to check num_in_channels + from diffusers import StableDiffusionXLInpaintPipeline + + try: + model = StableDiffusionXLInpaintPipeline.from_single_file( + model_abs_path, + load_safety_checker=False, + num_in_channels=9, + config_files=get_config_files(), + ) + if model.unet.config.in_channels == 9: + # https://github.com/huggingface/diffusers/issues/6610 + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + else: + model_type = ModelType.DIFFUSERS_SDXL + except ValueError as e: + if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e): + model_type = ModelType.DIFFUSERS_SDXL + else: + raise e + return model_type + + +def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: + cache_dir = Path(cache_dir) + stable_diffusion_dir = cache_dir / "stable_diffusion" + cache_file = stable_diffusion_dir / "iopaint_cache.json" + model_type_cache = {} + if cache_file.exists(): + try: + with open(cache_file, "r", encoding="utf-8") as f: + model_type_cache = json.load(f) + assert isinstance(model_type_cache, dict) + except: + pass + + res = [] + for it in stable_diffusion_dir.glob(f"*.*"): + if it.suffix not in [".safetensors", ".ckpt"]: + continue + model_abs_path = str(it.absolute()) + model_type = model_type_cache.get(it.name) + if model_type is None: + model_type = get_sd_model_type(model_abs_path) + model_type_cache[it.name] = model_type + res.append( + ModelInfo( + name=it.name, + path=model_abs_path, + model_type=model_type, + is_single_file_diffusers=True, + ) + ) + if stable_diffusion_dir.exists(): + with open(cache_file, "w", encoding="utf-8") as fw: + json.dump(model_type_cache, fw, indent=2, ensure_ascii=False) + + stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl" + sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json" + sdxl_model_type_cache = {} + if sdxl_cache_file.exists(): + try: + with open(sdxl_cache_file, "r", encoding="utf-8") as f: + sdxl_model_type_cache = json.load(f) + assert isinstance(sdxl_model_type_cache, dict) + except: + pass + + for it in stable_diffusion_xl_dir.glob(f"*.*"): + if it.suffix not in [".safetensors", ".ckpt"]: + continue + model_abs_path = str(it.absolute()) + model_type = sdxl_model_type_cache.get(it.name) + if model_type is None: + model_type = get_sdxl_model_type(model_abs_path) + sdxl_model_type_cache[it.name] = model_type + if stable_diffusion_xl_dir.exists(): + with open(sdxl_cache_file, "w", encoding="utf-8") as fw: + json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False) + + res.append( + ModelInfo( + name=it.name, + path=model_abs_path, + model_type=model_type, + is_single_file_diffusers=True, + ) + ) + return res + + +def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]: + res = [] + from iopaint.model import models + + # logger.info(f"Scanning inpaint models in {model_dir}") + + for name, m in models.items(): + if m.is_erase_model and m.is_downloaded(): + res.append( + ModelInfo( + name=name, + path=name, + model_type=ModelType.INPAINT, + ) + ) + return res + + +def scan_diffusers_models() -> List[ModelInfo]: + from huggingface_hub.constants import HF_HUB_CACHE + + available_models = [] + cache_dir = Path(HF_HUB_CACHE) + # logger.info(f"Scanning diffusers models in {cache_dir}") + diffusers_model_names = [] + for it in cache_dir.glob("**/*/model_index.json"): + with open(it, "r", encoding="utf-8") as f: + try: + data = json.load(f) + except: + continue + + _class_name = data["_class_name"] + name = folder_name_to_show_name(it.parent.parent.parent.name) + if name in diffusers_model_names: + continue + if "PowerPaint" in name: + model_type = ModelType.DIFFUSERS_OTHER + elif _class_name == DIFFUSERS_SD_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD + elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD_INPAINT + elif _class_name == DIFFUSERS_SDXL_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL + elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + elif _class_name in [ + "StableDiffusionInstructPix2PixPipeline", + "PaintByExamplePipeline", + "KandinskyV22InpaintPipeline", + "AnyText", + ]: + model_type = ModelType.DIFFUSERS_OTHER + else: + continue + + diffusers_model_names.append(name) + available_models.append( + ModelInfo( + name=name, + path=name, + model_type=model_type, + ) + ) + return available_models + + +def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]: + cache_dir = Path(cache_dir) + available_models = [] + diffusers_model_names = [] + for it in cache_dir.glob("**/*/model_index.json"): + with open(it, "r", encoding="utf-8") as f: + try: + data = json.load(f) + except: + logger.error( + f"Failed to load {it}, please try revert from original model or fix model_index.json by hand." + ) + continue + + _class_name = data["_class_name"] + name = folder_name_to_show_name(it.parent.name) + if name in diffusers_model_names: + continue + elif _class_name == DIFFUSERS_SD_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD + elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD_INPAINT + elif _class_name == DIFFUSERS_SDXL_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL + elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + else: + continue + + diffusers_model_names.append(name) + available_models.append( + ModelInfo( + name=name, + path=str(it.parent.absolute()), + model_type=model_type, + ) + ) + return available_models + + +def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]: + cache_dir = Path(cache_dir) + available_models = [] + stable_diffusion_dir = cache_dir / "stable_diffusion" + stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl" + available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir)) + available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir)) + return available_models + + +def scan_models() -> List[ModelInfo]: + model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR) + available_models = [] + available_models.extend(scan_inpaint_models(model_dir)) + available_models.extend(scan_single_file_diffusion_models(model_dir)) + available_models.extend(scan_diffusers_models()) + available_models.extend(scan_converted_diffusers_models(model_dir)) + return available_models diff --git a/custom-demo/back-end/file_manager/__init__.py b/custom-demo/back-end/file_manager/__init__.py new file mode 100644 index 0000000..1a24998 --- /dev/null +++ b/custom-demo/back-end/file_manager/__init__.py @@ -0,0 +1 @@ +from .file_manager import FileManager diff --git a/custom-demo/back-end/file_manager/file_manager.py b/custom-demo/back-end/file_manager/file_manager.py new file mode 100644 index 0000000..413162c --- /dev/null +++ b/custom-demo/back-end/file_manager/file_manager.py @@ -0,0 +1,215 @@ +import os +from io import BytesIO +from pathlib import Path +from typing import List + +from PIL import Image, ImageOps, PngImagePlugin +from fastapi import FastAPI, UploadFile, HTTPException +from starlette.responses import FileResponse + +from ..schema import MediasResponse, MediaTab + +LARGE_ENOUGH_NUMBER = 100 +PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) +from .storage_backends import FilesystemStorageBackend +from .utils import aspect_to_string, generate_filename, glob_img + + +class FileManager: + def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path): + self.app = app + self.input_dir: Path = input_dir + self.output_dir: Path = output_dir + + self.image_dir_filenames = [] + self.output_dir_filenames = [] + if not self.thumbnail_directory.exists(): + self.thumbnail_directory.mkdir(parents=True) + + # fmt: off + self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse]) + self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"]) + self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"]) + # fmt: on + + def api_medias(self, tab: MediaTab) -> List[MediasResponse]: + img_dir = self._get_dir(tab) + return self._media_names(img_dir) + + def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse: + file_path = self._get_file(tab, filename) + return FileResponse(file_path, media_type="image/png") + + # tab=${tab}?filename=${filename.name}?width=${width}&height=${height} + def api_media_thumbnail_file( + self, tab: MediaTab, filename: str, width: int, height: int + ) -> FileResponse: + img_dir = self._get_dir(tab) + thumb_filename, (width, height) = self.get_thumbnail( + img_dir, filename, width=width, height=height + ) + thumbnail_filepath = self.thumbnail_directory / thumb_filename + return FileResponse( + thumbnail_filepath, + headers={ + "X-Width": str(width), + "X-Height": str(height), + }, + media_type="image/jpeg", + ) + + def _get_dir(self, tab: MediaTab) -> Path: + if tab == "input": + return self.input_dir + elif tab == "output": + return self.output_dir + else: + raise HTTPException(status_code=422, detail=f"tab not found: {tab}") + + def _get_file(self, tab: MediaTab, filename: str) -> Path: + file_path = self._get_dir(tab) / filename + if not file_path.exists(): + raise HTTPException(status_code=422, detail=f"file not found: {file_path}") + return file_path + + @property + def thumbnail_directory(self) -> Path: + return self.output_dir / "thumbnails" + + @staticmethod + def _media_names(directory: Path) -> List[MediasResponse]: + names = sorted([it.name for it in glob_img(directory)]) + res = [] + for name in names: + path = os.path.join(directory, name) + img = Image.open(path) + res.append( + MediasResponse( + name=name, + height=img.height, + width=img.width, + ctime=os.path.getctime(path), + mtime=os.path.getmtime(path), + ) + ) + return res + + def get_thumbnail( + self, directory: Path, original_filename: str, width, height, **options + ): + directory = Path(directory) + storage = FilesystemStorageBackend(self.app) + crop = options.get("crop", "fit") + background = options.get("background") + quality = options.get("quality", 90) + + original_path, original_filename = os.path.split(original_filename) + original_filepath = os.path.join(directory, original_path, original_filename) + image = Image.open(BytesIO(storage.read(original_filepath))) + + # keep ratio resize + if not width and not height: + width = 256 + + if width != 0: + height = int(image.height * width / image.width) + else: + width = int(image.width * height / image.height) + + thumbnail_size = (width, height) + + thumbnail_filename = generate_filename( + directory, + original_filename, + aspect_to_string(thumbnail_size), + crop, + background, + quality, + ) + + thumbnail_filepath = os.path.join( + self.thumbnail_directory, original_path, thumbnail_filename + ) + + if storage.exists(thumbnail_filepath): + return thumbnail_filepath, (width, height) + + try: + image.load() + except (IOError, OSError): + self.app.logger.warning("Thumbnail not load image: %s", original_filepath) + return thumbnail_filepath, (width, height) + + # get original image format + options["format"] = options.get("format", image.format) + + image = self._create_thumbnail( + image, thumbnail_size, crop, background=background + ) + + raw_data = self.get_raw_data(image, **options) + storage.save(thumbnail_filepath, raw_data) + + return thumbnail_filepath, (width, height) + + def get_raw_data(self, image, **options): + data = { + "format": self._get_format(image, **options), + "quality": options.get("quality", 90), + } + + _file = BytesIO() + image.save(_file, **data) + return _file.getvalue() + + @staticmethod + def colormode(image, colormode="RGB"): + if colormode == "RGB" or colormode == "RGBA": + if image.mode == "RGBA": + return image + if image.mode == "LA": + return image.convert("RGBA") + return image.convert(colormode) + + if colormode == "GRAY": + return image.convert("L") + + return image.convert(colormode) + + @staticmethod + def background(original_image, color=0xFF): + size = (max(original_image.size),) * 2 + image = Image.new("L", size, color) + image.paste( + original_image, + tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))), + ) + + return image + + def _get_format(self, image, **options): + if options.get("format"): + return options.get("format") + if image.format: + return image.format + + return "JPEG" + + def _create_thumbnail(self, image, size, crop="fit", background=None): + try: + resample = Image.Resampling.LANCZOS + except AttributeError: # pylint: disable=raise-missing-from + resample = Image.ANTIALIAS + + if crop == "fit": + image = ImageOps.fit(image, size, resample) + else: + image = image.copy() + image.thumbnail(size, resample=resample) + + if background is not None: + image = self.background(image) + + image = self.colormode(image) + + return image diff --git a/custom-demo/back-end/file_manager/storage_backends.py b/custom-demo/back-end/file_manager/storage_backends.py new file mode 100644 index 0000000..3f453ad --- /dev/null +++ b/custom-demo/back-end/file_manager/storage_backends.py @@ -0,0 +1,46 @@ +# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py +import errno +import os +from abc import ABC, abstractmethod + + +class BaseStorageBackend(ABC): + def __init__(self, app=None): + self.app = app + + @abstractmethod + def read(self, filepath, mode="rb", **kwargs): + raise NotImplementedError + + @abstractmethod + def exists(self, filepath): + raise NotImplementedError + + @abstractmethod + def save(self, filepath, data): + raise NotImplementedError + + +class FilesystemStorageBackend(BaseStorageBackend): + def read(self, filepath, mode="rb", **kwargs): + with open(filepath, mode) as f: # pylint: disable=unspecified-encoding + return f.read() + + def exists(self, filepath): + return os.path.exists(filepath) + + def save(self, filepath, data): + directory = os.path.dirname(filepath) + + if not os.path.exists(directory): + try: + os.makedirs(directory) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + if not os.path.isdir(directory): + raise IOError("{} is not a directory".format(directory)) + + with open(filepath, "wb") as f: + f.write(data) diff --git a/custom-demo/back-end/file_manager/utils.py b/custom-demo/back-end/file_manager/utils.py new file mode 100644 index 0000000..f6890af --- /dev/null +++ b/custom-demo/back-end/file_manager/utils.py @@ -0,0 +1,65 @@ +# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py +import hashlib +from pathlib import Path + +from typing import Union + + +def generate_filename(directory: Path, original_filename, *options) -> str: + text = str(directory.absolute()) + original_filename + for v in options: + text += "%s" % v + md5_hash = hashlib.md5() + md5_hash.update(text.encode("utf-8")) + return md5_hash.hexdigest() + ".jpg" + + +def parse_size(size): + if isinstance(size, int): + # If the size parameter is a single number, assume square aspect. + return [size, size] + + if isinstance(size, (tuple, list)): + if len(size) == 1: + # If single value tuple/list is provided, exand it to two elements + return size + type(size)(size) + return size + + try: + thumbnail_size = [int(x) for x in size.lower().split("x", 1)] + except ValueError: + raise ValueError( # pylint: disable=raise-missing-from + "Bad thumbnail size format. Valid format is INTxINT." + ) + + if len(thumbnail_size) == 1: + # If the size parameter only contains a single integer, assume square aspect. + thumbnail_size.append(thumbnail_size[0]) + + return thumbnail_size + + +def aspect_to_string(size): + if isinstance(size, str): + return size + + return "x".join(map(str, size)) + + +IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"} + + +def glob_img(p: Union[Path, str], recursive: bool = False): + p = Path(p) + if p.is_file() and p.suffix in IMG_SUFFIX: + yield p + else: + if recursive: + files = Path(p).glob("**/*.*") + else: + files = Path(p).glob("*.*") + + for it in files: + if it.suffix not in IMG_SUFFIX: + continue + yield it diff --git a/custom-demo/back-end/helper.py b/custom-demo/back-end/helper.py new file mode 100644 index 0000000..1c99dcf --- /dev/null +++ b/custom-demo/back-end/helper.py @@ -0,0 +1,408 @@ +import base64 +import imghdr +import io +import os +import sys +from typing import List, Optional, Dict, Tuple + +from urllib.parse import urlparse +import cv2 +from PIL import Image, ImageOps, PngImagePlugin +import numpy as np +import torch +from iopaint.const import MPS_UNSUPPORT_MODELS +from loguru import logger +from torch.hub import download_url_to_file, get_dir +import hashlib + + +def md5sum(filename): + md5 = hashlib.md5() + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(128 * md5.block_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def switch_mps_device(model_name, device): + if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps": + logger.info(f"{model_name} not support mps, switch to cpu") + return torch.device("cpu") + return device + + +def get_cache_path_by_url(url): + parts = urlparse(url) + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + if not os.path.isdir(model_dir): + os.makedirs(model_dir) + filename = os.path.basename(parts.path) + cached_file = os.path.join(model_dir, filename) + return cached_file + + +def download_model(url, model_md5: str = None): + if os.path.exists(url): + cached_file = url + else: + cached_file = get_cache_path_by_url(url) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + download_url_to_file(url, cached_file, hash_prefix, progress=True) + if model_md5: + _md5 = md5sum(cached_file) + if model_md5 == _md5: + logger.info(f"Download model success, md5: {_md5}") + else: + try: + os.remove(cached_file) + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint." + f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n" + ) + except: + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart iopaint." + ) + exit(-1) + + return cached_file + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def handle_error(model_path, model_md5, e): + _md5 = md5sum(model_path) + if _md5 != model_md5: + try: + os.remove(model_path) + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint." + f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n" + ) + except: + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart iopaint." + ) + else: + logger.error( + f"Failed to load model {model_path}," + f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}" + ) + exit(-1) + + +def load_jit_model(url_or_path, device, model_md5: str): + if os.path.exists(url_or_path): + model_path = url_or_path + else: + model_path = download_model(url_or_path, model_md5) + + logger.info(f"Loading model from: {model_path}") + try: + model = torch.jit.load(model_path, map_location="cpu").to(device) + except Exception as e: + handle_error(model_path, model_md5, e) + model.eval() + return model + + +def load_model(model: torch.nn.Module, url_or_path, device, model_md5): + if os.path.exists(url_or_path): + model_path = url_or_path + else: + model_path = download_model(url_or_path, model_md5) + + try: + logger.info(f"Loading model from: {model_path}") + state_dict = torch.load(model_path, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + model.to(device) + except Exception as e: + handle_error(model_path, model_md5, e) + model.eval() + return model + + +def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: + data = cv2.imencode( + f".{ext}", + image_numpy, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + )[1] + image_bytes = data.tobytes() + return image_bytes + + +def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes: + with io.BytesIO() as output: + kwargs = {k: v for k, v in infos.items() if v is not None} + if ext == "jpg": + ext = "jpeg" + if "png" == ext.lower() and "parameters" in kwargs: + pnginfo_data = PngImagePlugin.PngInfo() + pnginfo_data.add_text("parameters", kwargs["parameters"]) + kwargs["pnginfo"] = pnginfo_data + + pil_img.save(output, format=ext, quality=quality, **kwargs) + image_bytes = output.getvalue() + return image_bytes + + +def load_img(img_bytes, gray: bool = False, return_info: bool = False): + alpha_channel = None + image = Image.open(io.BytesIO(img_bytes)) + + if return_info: + infos = image.info + + try: + image = ImageOps.exif_transpose(image) + except: + pass + + if gray: + image = image.convert("L") + np_img = np.array(image) + else: + if image.mode == "RGBA": + np_img = np.array(image) + alpha_channel = np_img[:, :, -1] + np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) + else: + image = image.convert("RGB") + np_img = np.array(image) + + if return_info: + return np_img, alpha_channel, infos + return np_img, alpha_channel + + +def norm_img(np_img): + if len(np_img.shape) == 2: + np_img = np_img[:, :, np.newaxis] + np_img = np.transpose(np_img, (2, 0, 1)) + np_img = np_img.astype("float32") / 255 + return np_img + + +def resize_max_size( + np_img, size_limit: int, interpolation=cv2.INTER_CUBIC +) -> np.ndarray: + # Resize image's longer size to size_limit if longer size larger than size_limit + h, w = np_img.shape[:2] + if max(h, w) > size_limit: + ratio = size_limit / max(h, w) + new_w = int(w * ratio + 0.5) + new_h = int(h * ratio + 0.5) + return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation) + else: + return np_img + + +def pad_img_to_modulo( + img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None +): + """ + + Args: + img: [H, W, C] + mod: + square: 是否为正方形 + min_size: + + Returns: + + """ + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + height, width = img.shape[:2] + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + + if min_size is not None: + assert min_size % mod == 0 + out_width = max(min_size, out_width) + out_height = max(min_size, out_height) + + if square: + max_size = max(out_height, out_width) + out_height = max_size + out_width = max_size + + return np.pad( + img, + ((0, out_height - height), (0, out_width - width), (0, 0)), + mode="symmetric", + ) + + +def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]: + """ + Args: + mask: (h, w, 1) 0~255 + + Returns: + + """ + height, width = mask.shape[:2] + _, thresh = cv2.threshold(mask, 127, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + boxes = [] + for cnt in contours: + x, y, w, h = cv2.boundingRect(cnt) + box = np.array([x, y, x + w, y + h]).astype(int) + + box[::2] = np.clip(box[::2], 0, width) + box[1::2] = np.clip(box[1::2], 0, height) + boxes.append(box) + + return boxes + + +def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]: + """ + Args: + mask: (h, w) 0~255 + + Returns: + + """ + _, thresh = cv2.threshold(mask, 127, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + max_area = 0 + max_index = -1 + for i, cnt in enumerate(contours): + area = cv2.contourArea(cnt) + if area > max_area: + max_area = area + max_index = i + + if max_index != -1: + new_mask = np.zeros_like(mask) + return cv2.drawContours(new_mask, contours, max_index, 255, -1) + else: + return mask + + +def is_mac(): + return sys.platform == "darwin" + + +def get_image_ext(img_bytes): + w = imghdr.what("", img_bytes) + if w is None: + w = "jpeg" + return w + + +def decode_base64_to_image( + encoding: str, gray=False +) -> Tuple[np.array, Optional[np.array], Dict]: + if encoding.startswith("data:image/") or encoding.startswith( + "data:application/octet-stream;base64," + ): + encoding = encoding.split(";")[1].split(",")[1] + image = Image.open(io.BytesIO(base64.b64decode(encoding))) + + alpha_channel = None + try: + image = ImageOps.exif_transpose(image) + except: + pass + # exif_transpose will remove exif rotate info,we must call image.info after exif_transpose + infos = image.info + + if gray: + image = image.convert("L") + np_img = np.array(image) + else: + if image.mode == "RGBA": + np_img = np.array(image) + alpha_channel = np_img[:, :, -1] + np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) + else: + image = image.convert("RGB") + np_img = np.array(image) + + return np_img, alpha_channel, infos + + +def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes: + img_bytes = pil_to_bytes( + image, + "png", + quality=quality, + infos=infos, + ) + return base64.b64encode(img_bytes) + + +def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray: + if alpha_channel is not None: + if alpha_channel.shape[:2] != rgb_np_img.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0]) + ) + rgb_np_img = np.concatenate( + (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 + ) + return rgb_np_img + + +def adjust_mask(mask: np.ndarray, kernel_size: int, operate): + # fronted brush color "ffcc00bb" + # kernel_size = kernel_size*2+1 + mask[mask >= 127] = 255 + mask[mask < 127] = 0 + + if operate == "reverse": + mask = 255 - mask + else: + kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1) + ) + if operate == "expand": + mask = cv2.dilate( + mask, + kernel, + iterations=1, + ) + else: + mask = cv2.erode( + mask, + kernel, + iterations=1, + ) + res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) + res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)] + res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA) + return res_mask + + +def gen_frontend_mask(bgr_or_gray_mask): + if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1: + bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY) + + # fronted brush color "ffcc00bb" + # TODO: how to set kernel size? + kernel_size = 9 + bgr_or_gray_mask = cv2.dilate( + bgr_or_gray_mask, + np.ones((kernel_size, kernel_size), np.uint8), + iterations=1, + ) + res_mask = np.zeros( + (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8 + ) + res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)] + res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA) + return res_mask diff --git a/custom-demo/back-end/installer.py b/custom-demo/back-end/installer.py new file mode 100644 index 0000000..f255e33 --- /dev/null +++ b/custom-demo/back-end/installer.py @@ -0,0 +1,12 @@ +import subprocess +import sys + + +def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + +def install_plugins_package(): + install("rembg") + install("realesrgan") + install("gfpgan") diff --git a/custom-demo/back-end/main.py b/custom-demo/back-end/main.py new file mode 100644 index 0000000..807c785 --- /dev/null +++ b/custom-demo/back-end/main.py @@ -0,0 +1,22 @@ +# __init__.py +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068 +os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1" +os.environ["LRU_CACHE_CAPACITY"] = "1" +# prevent CPU memory leak when run model on GPU +# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431 +# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633 +os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1" + + +import warnings + +warnings.simplefilter("ignore", UserWarning) +from iopaint.cli import typer_app + + + +if __name__ == "__main__": + typer_app() diff --git a/custom-demo/back-end/model/__init__.py b/custom-demo/back-end/model/__init__.py new file mode 100644 index 0000000..799e2ec --- /dev/null +++ b/custom-demo/back-end/model/__init__.py @@ -0,0 +1,37 @@ +from .anytext.anytext_model import AnyText +from .controlnet import ControlNet +from .fcf import FcF +from .instruct_pix2pix import InstructPix2Pix +from .kandinsky import Kandinsky22 +from .lama import LaMa +from .ldm import LDM +from .manga import Manga +from .mat import MAT +from .mi_gan import MIGAN +from .opencv2 import OpenCV2 +from .paint_by_example import PaintByExample +from .power_paint.power_paint import PowerPaint +from .sd import SD15, SD2, Anything4, RealisticVision14, SD +from .sdxl import SDXL +from .zits import ZITS + +models = { + LaMa.name: LaMa, + LDM.name: LDM, + ZITS.name: ZITS, + MAT.name: MAT, + FcF.name: FcF, + OpenCV2.name: OpenCV2, + Manga.name: Manga, + MIGAN.name: MIGAN, + SD15.name: SD15, + Anything4.name: Anything4, + RealisticVision14.name: RealisticVision14, + SD2.name: SD2, + PaintByExample.name: PaintByExample, + InstructPix2Pix.name: InstructPix2Pix, + Kandinsky22.name: Kandinsky22, + SDXL.name: SDXL, + PowerPaint.name: PowerPaint, + AnyText.name: AnyText, +} diff --git a/custom-demo/back-end/model/anytext/__init__.py b/custom-demo/back-end/model/anytext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/anytext_model.py b/custom-demo/back-end/model/anytext/anytext_model.py new file mode 100644 index 0000000..374669e --- /dev/null +++ b/custom-demo/back-end/model/anytext/anytext_model.py @@ -0,0 +1,73 @@ +import torch +from huggingface_hub import hf_hub_download + +from iopaint.const import ANYTEXT_NAME +from iopaint.model.anytext.anytext_pipeline import AnyTextPipeline +from iopaint.model.base import DiffusionInpaintModel +from iopaint.model.utils import get_torch_dtype, is_local_files_only +from iopaint.schema import InpaintRequest + + +class AnyText(DiffusionInpaintModel): + name = ANYTEXT_NAME + pad_mod = 64 + is_erase_model = False + + @staticmethod + def download(local_files_only=False): + hf_hub_download( + repo_id=ANYTEXT_NAME, + filename="model_index.json", + local_files_only=local_files_only, + ) + ckpt_path = hf_hub_download( + repo_id=ANYTEXT_NAME, + filename="pytorch_model.fp16.safetensors", + local_files_only=local_files_only, + ) + font_path = hf_hub_download( + repo_id=ANYTEXT_NAME, + filename="SourceHanSansSC-Medium.otf", + local_files_only=local_files_only, + ) + return ckpt_path, font_path + + def init_model(self, device, **kwargs): + local_files_only = is_local_files_only(**kwargs) + ckpt_path, font_path = self.download(local_files_only) + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + self.model = AnyTextPipeline( + ckpt_path=ckpt_path, + font_path=font_path, + device=device, + use_fp16=torch_dtype == torch.float16, + ) + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to inpainting + return: BGR IMAGE + """ + height, width = image.shape[:2] + mask = mask.astype("float32") / 255.0 + masked_image = image * (1 - mask) + + # list of rgb ndarray + results, rtn_code, rtn_warning = self.model( + image=image, + masked_image=masked_image, + prompt=config.prompt, + negative_prompt=config.negative_prompt, + num_inference_steps=config.sd_steps, + strength=config.sd_strength, + guidance_scale=config.sd_guidance_scale, + height=height, + width=width, + seed=config.sd_seed, + sort_priority="y", + callback=self.callback + ) + inpainted_rgb_image = results[0][..., ::-1] + return inpainted_rgb_image diff --git a/custom-demo/back-end/model/anytext/anytext_pipeline.py b/custom-demo/back-end/model/anytext/anytext_pipeline.py new file mode 100644 index 0000000..5051272 --- /dev/null +++ b/custom-demo/back-end/model/anytext/anytext_pipeline.py @@ -0,0 +1,403 @@ +""" +AnyText: Multilingual Visual Text Generation And Editing +Paper: https://arxiv.org/abs/2311.03054 +Code: https://github.com/tyxsspa/AnyText +Copyright (c) Alibaba, Inc. and its affiliates. +""" +import os +from pathlib import Path + +from iopaint.model.utils import set_seed +from safetensors.torch import load_file + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import torch +import re +import numpy as np +import cv2 +import einops +from PIL import ImageFont +from iopaint.model.anytext.cldm.model import create_model, load_state_dict +from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler +from iopaint.model.anytext.utils import ( + check_channels, + draw_glyph, + draw_glyph2, +) + + +BBOX_MAX_NUM = 8 +PLACE_HOLDER = "*" +max_chars = 20 + +ANYTEXT_CFG = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml" +) + + +def check_limits(tensor): + float16_min = torch.finfo(torch.float16).min + float16_max = torch.finfo(torch.float16).max + + # 检查张量中是否有值小于float16的最小值或大于float16的最大值 + is_below_min = (tensor < float16_min).any() + is_above_max = (tensor > float16_max).any() + + return is_below_min or is_above_max + + +class AnyTextPipeline: + def __init__(self, ckpt_path, font_path, device, use_fp16=True): + self.cfg_path = ANYTEXT_CFG + self.font_path = font_path + self.use_fp16 = use_fp16 + self.device = device + + self.font = ImageFont.truetype(font_path, size=60) + self.model = create_model( + self.cfg_path, + device=self.device, + use_fp16=self.use_fp16, + ) + if self.use_fp16: + self.model = self.model.half() + if Path(ckpt_path).suffix == ".safetensors": + state_dict = load_file(ckpt_path, device="cpu") + else: + state_dict = load_state_dict(ckpt_path, location="cpu") + self.model.load_state_dict(state_dict, strict=False) + self.model = self.model.eval().to(self.device) + self.ddim_sampler = DDIMSampler(self.model, device=self.device) + + def __call__( + self, + prompt: str, + negative_prompt: str, + image: np.ndarray, + masked_image: np.ndarray, + num_inference_steps: int, + strength: float, + guidance_scale: float, + height: int, + width: int, + seed: int, + sort_priority: str = "y", + callback=None, + ): + """ + + Args: + prompt: + negative_prompt: + image: + masked_image: + num_inference_steps: + strength: + guidance_scale: + height: + width: + seed: + sort_priority: x: left-right, y: top-down + + Returns: + result: list of images in numpy.ndarray format + rst_code: 0: normal -1: error 1:warning + rst_info: string of error or warning + + """ + set_seed(seed) + str_warning = "" + + mode = "text-editing" + revise_pos = False + img_count = 1 + ddim_steps = num_inference_steps + w = width + h = height + strength = strength + cfg_scale = guidance_scale + eta = 0.0 + + prompt, texts = self.modify_prompt(prompt) + if prompt is None and texts is None: + return ( + None, + -1, + "You have input Chinese prompt but the translator is not loaded!", + "", + ) + n_lines = len(texts) + if mode in ["text-generation", "gen"]: + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode in ["text-editing", "edit"]: + if masked_image is None or image is None: + return ( + None, + -1, + "Reference image and position image are needed for text editing!", + "", + ) + if isinstance(image, str): + image = cv2.imread(image)[..., ::-1] + assert image is not None, f"Can't read ori_image image from{image}!" + elif isinstance(image, torch.Tensor): + image = image.cpu().numpy() + else: + assert isinstance( + image, np.ndarray + ), f"Unknown format of ori_image: {type(image)}" + edit_image = image.clip(1, 255) # for mask reason + edit_image = check_channels(edit_image) + # edit_image = resize_image( + # edit_image, max_length=768 + # ) # make w h multiple of 64, resize if w or h > max_length + h, w = edit_image.shape[:2] # change h, w by input ref_img + # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) + if masked_image is None: + pos_imgs = np.zeros((w, h, 1)) + if isinstance(masked_image, str): + masked_image = cv2.imread(masked_image)[..., ::-1] + assert ( + masked_image is not None + ), f"Can't read draw_pos image from{masked_image}!" + pos_imgs = 255 - masked_image + elif isinstance(masked_image, torch.Tensor): + pos_imgs = masked_image.cpu().numpy() + else: + assert isinstance( + masked_image, np.ndarray + ), f"Unknown format of draw_pos: {type(masked_image)}" + pos_imgs = 255 - masked_image + pos_imgs = pos_imgs[..., 0:1] + pos_imgs = cv2.convertScaleAbs(pos_imgs) + _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) + # seprate pos_imgs + pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) + if len(pos_imgs) == 0: + pos_imgs = [np.zeros((h, w, 1))] + if len(pos_imgs) < n_lines: + if n_lines == 1 and texts[0] == " ": + pass # text-to-image without text + else: + raise RuntimeError( + f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images" + ) + elif len(pos_imgs) > n_lines: + str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." + # get pre_pos, poly_list, hint that needed for anytext + pre_pos = [] + poly_list = [] + for input_pos in pos_imgs: + if input_pos.mean() != 0: + input_pos = ( + input_pos[..., np.newaxis] + if len(input_pos.shape) == 2 + else input_pos + ) + poly, pos_img = self.find_polygon(input_pos) + pre_pos += [pos_img / 255.0] + poly_list += [poly] + else: + pre_pos += [np.zeros((h, w, 1))] + poly_list += [None] + np_hint = np.sum(pre_pos, axis=0).clip(0, 1) + # prepare info dict + info = {} + info["glyphs"] = [] + info["gly_line"] = [] + info["positions"] = [] + info["n_lines"] = [len(texts)] * img_count + gly_pos_imgs = [] + for i in range(len(texts)): + text = texts[i] + if len(text) > max_chars: + str_warning = ( + f'"{text}" length > max_chars: {max_chars}, will be cut off...' + ) + text = text[:max_chars] + gly_scale = 2 + if pre_pos[i].mean() != 0: + gly_line = draw_glyph(self.font, text) + glyphs = draw_glyph2( + self.font, + text, + poly_list[i], + scale=gly_scale, + width=w, + height=h, + add_space=False, + ) + gly_pos_img = cv2.drawContours( + glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1 + ) + if revise_pos: + resize_gly = cv2.resize( + glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]) + ) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones( + (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), + dtype=np.uint8, + ), + iterations=1, + ) + new_pos = ( + new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos + ) + contours, _ = cv2.findContours( + new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + if len(contours) != 1: + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + else: + rect = cv2.minAreaRect(contours[0]) + poly = np.int0(cv2.boxPoints(rect)) + pre_pos[i] = ( + cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 + ) + gly_pos_img = cv2.drawContours( + glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1 + ) + gly_pos_imgs += [gly_pos_img] # for show + else: + glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) + gly_line = np.zeros((80, 512, 1)) + gly_pos_imgs += [ + np.zeros((h * gly_scale, w * gly_scale, 1)) + ] # for show + pos = pre_pos[i] + info["glyphs"] += [self.arr2tensor(glyphs, img_count)] + info["gly_line"] += [self.arr2tensor(gly_line, img_count)] + info["positions"] += [self.arr2tensor(pos, img_count)] + # get masked_x + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = np.transpose(masked_img, (2, 0, 1)) + masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device) + if self.use_fp16: + masked_img = masked_img.half() + encoder_posterior = self.model.encode_first_stage(masked_img[None, ...]) + masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach() + if self.use_fp16: + masked_x = masked_x.half() + info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0) + + hint = self.arr2tensor(np_hint, img_count) + cond = self.model.get_learned_conditioning( + dict( + c_concat=[hint], + c_crossattn=[[prompt] * img_count], + text_info=info, + ) + ) + un_cond = self.model.get_learned_conditioning( + dict( + c_concat=[hint], + c_crossattn=[[negative_prompt] * img_count], + text_info=info, + ) + ) + shape = (4, h // 8, w // 8) + self.model.control_scales = [strength] * 13 + samples, intermediates = self.ddim_sampler.sample( + ddim_steps, + img_count, + shape, + cond, + verbose=False, + eta=eta, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=un_cond, + callback=callback + ) + if self.use_fp16: + samples = samples.half() + x_samples = self.model.decode_first_stage(samples) + x_samples = ( + (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5) + .cpu() + .numpy() + .clip(0, 255) + .astype(np.uint8) + ) + results = [x_samples[i] for i in range(img_count)] + # if ( + # mode == "edit" and False + # ): # replace backgound in text editing but not ideal yet + # results = [r * np_hint + edit_image * (1 - np_hint) for r in results] + # results = [r.clip(0, 255).astype(np.uint8) for r in results] + # if len(gly_pos_imgs) > 0 and show_debug: + # glyph_bs = np.stack(gly_pos_imgs, axis=2) + # glyph_img = np.sum(glyph_bs, axis=2) * 255 + # glyph_img = glyph_img.clip(0, 255).astype(np.uint8) + # results += [np.repeat(glyph_img, 3, axis=2)] + rst_code = 1 if str_warning else 0 + return results, rst_code, str_warning + + def modify_prompt(self, prompt): + prompt = prompt.replace("“", '"') + prompt = prompt.replace("”", '"') + p = '"(.*?)"' + strs = re.findall(p, prompt) + if len(strs) == 0: + strs = [" "] + else: + for s in strs: + prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1) + # if self.is_chinese(prompt): + # if self.trans_pipe is None: + # return None, None + # old_prompt = prompt + # prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1] + # print(f"Translate: {old_prompt} --> {prompt}") + return prompt, strs + + # def is_chinese(self, text): + # text = checker._clean_text(text) + # for char in text: + # cp = ord(char) + # if checker._is_chinese_char(cp): + # return True + # return False + + def separate_pos_imgs(self, img, sort_priority, gap=102): + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) + components = [] + for label in range(1, num_labels): + component = np.zeros_like(img) + component[labels == label] = 255 + components.append((component, centroids[label])) + if sort_priority == "y": + fir, sec = 1, 0 # top-down first + elif sort_priority == "x": + fir, sec = 0, 1 # left-right first + components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) + sorted_components = [c[0] for c in components] + return sorted_components + + def find_polygon(self, image, min_rect=False): + contours, hierarchy = cv2.findContours( + image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + max_contour = max(contours, key=cv2.contourArea) # get contour with max area + if min_rect: + # get minimum enclosing rectangle + rect = cv2.minAreaRect(max_contour) + poly = np.int0(cv2.boxPoints(rect)) + else: + # get approximate polygon + epsilon = 0.01 * cv2.arcLength(max_contour, True) + poly = cv2.approxPolyDP(max_contour, epsilon, True) + n, _, xy = poly.shape + poly = poly.reshape(n, xy) + cv2.drawContours(image, [poly], -1, 255, -1) + return poly, image + + def arr2tensor(self, arr, bs): + arr = np.transpose(arr, (2, 0, 1)) + _arr = torch.from_numpy(arr.copy()).float().to(self.device) + if self.use_fp16: + _arr = _arr.half() + _arr = torch.stack([_arr for _ in range(bs)], dim=0) + return _arr diff --git a/custom-demo/back-end/model/anytext/anytext_sd15.yaml b/custom-demo/back-end/model/anytext/anytext_sd15.yaml new file mode 100644 index 0000000..d727594 --- /dev/null +++ b/custom-demo/back-end/model/anytext/anytext_sd15.yaml @@ -0,0 +1,99 @@ +model: + target: iopaint.model.anytext.cldm.cldm.ControlLDM + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "img" + cond_stage_key: "caption" + control_key: "hint" + glyph_key: "glyphs" + position_key: "positions" + image_size: 64 + channels: 4 + cond_stage_trainable: true # need be true when embedding_manager is valid + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + only_mid_control: False + loss_alpha: 0 # perceptual loss, 0.003 + loss_beta: 0 # ctc loss + latin_weight: 1.0 # latin text line may need smaller weigth + with_step_weight: true + use_vae_upsample: true + embedding_manager_config: + target: iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager + params: + valid: true # v6 + emb_type: ocr # ocr, vit, conv + glyph_channels: 1 + position_channels: 1 + add_pos: false + placeholder_string: '*' + + control_stage_config: + target: iopaint.model.anytext.cldm.cldm.ControlNet + params: + image_size: 32 # unused + in_channels: 4 + model_channels: 320 + glyph_channels: 1 + position_channels: 1 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + unet_config: + target: iopaint.model.anytext.cldm.cldm.ControlledUnetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3 + params: + version: openai/clip-vit-large-patch14 + use_vision: false # v6 diff --git a/custom-demo/back-end/model/anytext/cldm/__init__.py b/custom-demo/back-end/model/anytext/cldm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/cldm/cldm.py b/custom-demo/back-end/model/anytext/cldm/cldm.py new file mode 100644 index 0000000..ad9692a --- /dev/null +++ b/custom-demo/back-end/model/anytext/cldm/cldm.py @@ -0,0 +1,630 @@ +import os +from pathlib import Path + +import einops +import torch +import torch as th +import torch.nn as nn +import copy +from easydict import EasyDict as edict + +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + zero_module, + timestep_embedding, +) + +from einops import rearrange, repeat +from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer +from iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion +from iopaint.model.anytext.ldm.util import log_txt_as_img, exists, instantiate_from_config +from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler +from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from .recognizer import TextRecognizer, create_predictor + +CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__))) + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class ControlledUnetModel(UNetModel): + def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): + hs = [] + with torch.no_grad(): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + if self.use_fp16: + t_emb = t_emb.half() + emb = self.time_embed(t_emb) + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + + if control is not None: + h += control.pop() + + for i, module in enumerate(self.output_blocks): + if only_mid_control or control is None: + h = torch.cat([h, hs.pop()], dim=1) + else: + h = torch.cat([h, hs.pop() + control.pop()], dim=1) + h = module(h, emb, context) + + h = h.type(x.dtype) + return self.out(h) + + +class ControlNet(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + glyph_channels, + position_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.use_fp16 = use_fp16 + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self.glyph_block = TimestepEmbedSequential( + conv_nd(dims, glyph_channels, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 16, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.position_block = TimestepEmbedSequential( + conv_nd(dims, position_channels, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 8, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 8, 16, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 64, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1)) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, text_info, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + if self.use_fp16: + t_emb = t_emb.half() + emb = self.time_embed(t_emb) + + # guided_hint from text_info + B, C, H, W = x.shape + glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True) + enc_glyph = self.glyph_block(glyphs, emb, context) + enc_pos = self.position_block(positions, emb, context) + guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1)) + + outs = [] + + h = x.type(self.dtype) + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + return outs + + +class ControlLDM(LatentDiffusion): + + def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs): + self.use_fp16 = kwargs.pop('use_fp16', False) + super().__init__(*args, **kwargs) + self.control_model = instantiate_from_config(control_stage_config) + self.control_key = control_key + self.glyph_key = glyph_key + self.position_key = position_key + self.only_mid_control = only_mid_control + self.control_scales = [1.0] * 13 + self.loss_alpha = loss_alpha + self.loss_beta = loss_beta + self.with_step_weight = with_step_weight + self.use_vae_upsample = use_vae_upsample + self.latin_weight = latin_weight + + if embedding_manager_config is not None and embedding_manager_config.params.valid: + self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model) + for param in self.embedding_manager.embedding_parameters(): + param.requires_grad = True + else: + self.embedding_manager = None + if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager: + if embedding_manager_config.params.emb_type == 'ocr': + self.text_predictor = create_predictor().eval() + args = edict() + args.rec_image_shape = "3, 48, 320" + args.rec_batch_num = 6 + args.rec_char_dict_path = str(CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt") + args.use_fp16 = self.use_fp16 + self.cn_recognizer = TextRecognizer(args, self.text_predictor) + for param in self.text_predictor.parameters(): + param.requires_grad = False + if self.embedding_manager: + self.embedding_manager.recog = self.cn_recognizer + + @torch.no_grad() + def get_input(self, batch, k, bs=None, *args, **kwargs): + if self.embedding_manager is None: # fill in full caption + self.fill_caption(batch) + x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs) + control = batch[self.control_key] # for log_images and loss_alpha, not real control + if bs is not None: + control = control[:bs] + control = control.to(self.device) + control = einops.rearrange(control, 'b h w c -> b c h w') + control = control.to(memory_format=torch.contiguous_format).float() + + inv_mask = batch['inv_mask'] + if bs is not None: + inv_mask = inv_mask[:bs] + inv_mask = inv_mask.to(self.device) + inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w') + inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float() + + glyphs = batch[self.glyph_key] + gly_line = batch['gly_line'] + positions = batch[self.position_key] + n_lines = batch['n_lines'] + language = batch['language'] + texts = batch['texts'] + assert len(glyphs) == len(positions) + for i in range(len(glyphs)): + if bs is not None: + glyphs[i] = glyphs[i][:bs] + gly_line[i] = gly_line[i][:bs] + positions[i] = positions[i][:bs] + n_lines = n_lines[:bs] + glyphs[i] = glyphs[i].to(self.device) + gly_line[i] = gly_line[i].to(self.device) + positions[i] = positions[i].to(self.device) + glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w') + gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w') + positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w') + glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float() + gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float() + positions[i] = positions[i].to(memory_format=torch.contiguous_format).float() + info = {} + info['glyphs'] = glyphs + info['positions'] = positions + info['n_lines'] = n_lines + info['language'] = language + info['texts'] = texts + info['img'] = batch['img'] # nhwc, (-1,1) + info['masked_x'] = mx + info['gly_line'] = gly_line + info['inv_mask'] = inv_mask + return x, dict(c_crossattn=[c], c_concat=[control], text_info=info) + + def apply_model(self, x_noisy, t, cond, *args, **kwargs): + assert isinstance(cond, dict) + diffusion_model = self.model.diffusion_model + _cond = torch.cat(cond['c_crossattn'], 1) + _hint = torch.cat(cond['c_concat'], 1) + if self.use_fp16: + x_noisy = x_noisy.half() + control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info']) + control = [c * scale for c, scale in zip(control, self.control_scales)] + eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control) + + return eps + + def instantiate_embedding_manager(self, config, embedder): + model = instantiate_from_config(config, embedder=embedder) + return model + + @torch.no_grad() + def get_unconditional_conditioning(self, N): + return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None)) + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + if self.embedding_manager is not None and c['text_info'] is not None: + self.embedding_manager.encode_text(c['text_info']) + if isinstance(c, dict): + cond_txt = c['c_crossattn'][0] + else: + cond_txt = c + if self.embedding_manager is not None: + cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager) + else: + cond_txt = self.cond_stage_model.encode(cond_txt) + if isinstance(c, dict): + c['c_crossattn'][0] = cond_txt + else: + c = cond_txt + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def fill_caption(self, batch, place_holder='*'): + bs = len(batch['n_lines']) + cond_list = copy.deepcopy(batch[self.cond_stage_key]) + for i in range(bs): + n_lines = batch['n_lines'][i] + if n_lines == 0: + continue + cur_cap = cond_list[i] + for j in range(n_lines): + r_txt = batch['texts'][j][i] + cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1) + cond_list[i] = cur_cap + batch[self.cond_stage_key] = cond_list + + @torch.no_grad() + def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + use_ddim = ddim_steps is not None + + log = dict() + z, c = self.get_input(batch, self.first_stage_key, bs=N) + if self.cond_stage_trainable: + with torch.no_grad(): + c = self.get_learned_conditioning(c) + c_crossattn = c["c_crossattn"][0][:N] + c_cat = c["c_concat"][0][:N] + text_info = c["text_info"] + text_info['glyphs'] = [i[:N] for i in text_info['glyphs']] + text_info['gly_line'] = [i[:N] for i in text_info['gly_line']] + text_info['positions'] = [i[:N] for i in text_info['positions']] + text_info['n_lines'] = text_info['n_lines'][:N] + text_info['masked_x'] = text_info['masked_x'][:N] + text_info['img'] = text_info['img'][:N] + + N = min(z.shape[0], N) + n_row = min(z.shape[0], n_row) + log["reconstruction"] = self.decode_first_stage(z) + log["masked_image"] = self.decode_first_stage(text_info['masked_x']) + log["control"] = c_cat * 2.0 - 1.0 + log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed + # get glyph + glyph_bs = torch.stack(text_info['glyphs']) + glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0 + log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,) + # fill caption + if not self.embedding_manager: + self.fill_caption(batch) + captions = batch[self.cond_stage_key] + log["conditioning"] = log_txt_as_img((512, 512), captions, size=16) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N) + uc_cat = c_cat # torch.zeros_like(c_cat) + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info} + samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + pred_x0 = False # wether log pred_x0 + if pred_x0: + for idx in range(len(tmps['pred_x0'])): + pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx]) + log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0 + + return log + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + ddim_sampler = DDIMSampler(self) + b, c, h, w = cond["c_concat"][0].shape + shape = (self.channels, h // 8, w // 8) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs) + return samples, intermediates + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.control_model.parameters()) + if self.embedding_manager: + params += list(self.embedding_manager.embedding_parameters()) + if not self.sd_locked: + # params += list(self.model.diffusion_model.input_blocks.parameters()) + # params += list(self.model.diffusion_model.middle_block.parameters()) + params += list(self.model.diffusion_model.output_blocks.parameters()) + params += list(self.model.diffusion_model.out.parameters()) + if self.unlockKV: + nCount = 0 + for name, param in self.model.diffusion_model.named_parameters(): + if 'attn2.to_k' in name or 'attn2.to_v' in name: + params += [param] + nCount += 1 + print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to potimizers!!!') + + opt = torch.optim.AdamW(params, lr=lr) + return opt + + def low_vram_shift(self, is_diffusing): + if is_diffusing: + self.model = self.model.cuda() + self.control_model = self.control_model.cuda() + self.first_stage_model = self.first_stage_model.cpu() + self.cond_stage_model = self.cond_stage_model.cpu() + else: + self.model = self.model.cpu() + self.control_model = self.control_model.cpu() + self.first_stage_model = self.first_stage_model.cuda() + self.cond_stage_model = self.cond_stage_model.cuda() diff --git a/custom-demo/back-end/model/anytext/cldm/ddim_hacked.py b/custom-demo/back-end/model/anytext/cldm/ddim_hacked.py new file mode 100644 index 0000000..b23a883 --- /dev/null +++ b/custom-demo/back-end/model/anytext/cldm/ddim_hacked.py @@ -0,0 +1,486 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + + +class DDIMSampler(object): + def __init__(self, model, device, schedule="linear", **kwargs): + super().__init__() + self.device = device + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.device): + attr = attr.to(torch.device(self.device)) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f"Data shape for DDIM sampling is {size}, eta {eta}") + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + img, pred_x0 = outs + if callback: + callback(None, i, None, None) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + model_output = self.model.apply_model(x, t, c) + else: + model_t = self.model.apply_model(x, t, c) + model_uncond = self.model.apply_model(x, t, unconditional_conditioning) + model_output = model_uncond + unconditional_guidance_scale * ( + model_t - model_uncond + ) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", "not implemented" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + num_reference_steps = timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc="Encoding Image"): + t = torch.full( + (x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long + ) + if unconditional_guidance_scale == 1.0: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model( + torch.cat((x_next, x_next)), + torch.cat((t, t)), + torch.cat((unconditional_conditioning, c)), + ), + 2, + ) + noise_pred = e_t_uncond + unconditional_guidance_scale * ( + noise_pred - e_t_uncond + ) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = ( + alphas_next[i].sqrt() + * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) + * noise_pred + ) + x_next = xt_weighted + weighted_noise_pred + if ( + return_intermediates + and i % (num_steps // return_intermediates) == 0 + and i < num_steps - 1 + ): + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: + callback(i) + + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} + if return_intermediates: + out.update({"intermediates": intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec diff --git a/custom-demo/back-end/model/anytext/cldm/embedding_manager.py b/custom-demo/back-end/model/anytext/cldm/embedding_manager.py new file mode 100644 index 0000000..6ccf8a9 --- /dev/null +++ b/custom-demo/back-end/model/anytext/cldm/embedding_manager.py @@ -0,0 +1,165 @@ +''' +Copyright (c) Alibaba, Inc. and its affiliates. +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"] + assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" + return tokens[0, 1] + + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" + token = token[0, 1] + return token + + +def get_clip_vision_emb(encoder, processor, img): + _img = img.repeat(1, 3, 1, 1)*255 + inputs = processor(images=_img, return_tensors="pt") + inputs['pixel_values'] = inputs['pixel_values'].to(img.device) + outputs = encoder(**inputs) + emb = outputs.image_embeds + return emb + + +def get_recog_emb(encoder, img_list): + _img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list] + encoder.predictor.eval() + _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) + return preds_neck + + +def pad_H(x): + _, _, H, W = x.shape + p_top = (W - H) // 2 + p_bot = W - H - p_top + return F.pad(x, (0, 0, p_top, p_bot)) + + +class EncodeNet(nn.Module): + def __init__(self, in_channels, out_channels): + super(EncodeNet, self).__init__() + chan = 16 + n_layer = 4 # downsample + + self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1) + self.conv_list = nn.ModuleList([]) + _c = chan + for i in range(n_layer): + self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2)) + _c *= 2 + self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.act = nn.SiLU() + + def forward(self, x): + x = self.act(self.conv1(x)) + for layer in self.conv_list: + x = self.act(layer(x)) + x = self.act(self.conv2(x)) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + return x + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + valid=True, + glyph_channels=20, + position_channels=1, + placeholder_string='*', + add_pos=False, + emb_type='ocr', + **kwargs + ): + super().__init__() + if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + token_dim = 768 + if hasattr(embedder, 'vit'): + assert emb_type == 'vit' + self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor) + self.get_recog_emb = None + else: # using LDM's BERT encoder + get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + token_dim = 1280 + self.token_dim = token_dim + self.emb_type = emb_type + + self.add_pos = add_pos + if add_pos: + self.position_encoder = EncodeNet(position_channels, token_dim) + if emb_type == 'ocr': + self.proj = linear(40*64, token_dim) + if emb_type == 'conv': + self.glyph_encoder = EncodeNet(glyph_channels, token_dim) + + self.placeholder_token = get_token_for_string(placeholder_string) + + def encode_text(self, text_info): + if self.get_recog_emb is None and self.emb_type == 'ocr': + self.get_recog_emb = partial(get_recog_emb, self.recog) + + gline_list = [] + pos_list = [] + for i in range(len(text_info['n_lines'])): # sample index in a batch + n_lines = text_info['n_lines'][i] + for j in range(n_lines): # line + gline_list += [text_info['gly_line'][j][i:i+1]] + if self.add_pos: + pos_list += [text_info['positions'][j][i:i+1]] + + if len(gline_list) > 0: + if self.emb_type == 'ocr': + recog_emb = self.get_recog_emb(gline_list) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1)) + elif self.emb_type == 'vit': + enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0))) + elif self.emb_type == 'conv': + enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0))) + if self.add_pos: + enc_pos = self.position_encoder(torch.cat(gline_list, dim=0)) + enc_glyph = enc_glyph+enc_pos + + self.text_embs_all = [] + n_idx = 0 + for i in range(len(text_info['n_lines'])): # sample index in a batch + n_lines = text_info['n_lines'][i] + text_embs = [] + for j in range(n_lines): # line + text_embs += [enc_glyph[n_idx:n_idx+1]] + n_idx += 1 + self.text_embs_all += [text_embs] + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, device = tokenized_text.shape[0], tokenized_text.device + for i in range(b): + idx = tokenized_text[i] == self.placeholder_token.to(device) + if sum(idx) > 0: + if i >= len(self.text_embs_all): + print('truncation for log images...') + break + text_emb = torch.cat(self.text_embs_all[i], dim=0) + if sum(idx) != len(text_emb): + print('truncation for long caption...') + embedded_text[i][idx] = text_emb[:sum(idx)] + return embedded_text + + def embedding_parameters(self): + return self.parameters() diff --git a/custom-demo/back-end/model/anytext/cldm/hack.py b/custom-demo/back-end/model/anytext/cldm/hack.py new file mode 100644 index 0000000..05afe5f --- /dev/null +++ b/custom-demo/back-end/model/anytext/cldm/hack.py @@ -0,0 +1,111 @@ +import torch +import einops + +import iopaint.model.anytext.ldm.modules.encoders.modules +import iopaint.model.anytext.ldm.modules.attention + +from transformers import logging +from iopaint.model.anytext.ldm.modules.attention import default + + +def disable_verbosity(): + logging.set_verbosity_error() + print('logging improved.') + return + + +def enable_sliced_attention(): + iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward + print('Enabled sliced_attention.') + return + + +def hack_everything(clip_skip=0): + disable_verbosity() + iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward + iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip + print('Enabled clip hacks.') + return + + +# Written by Lvmin +def _hacked_clip_forward(self, text): + PAD = self.tokenizer.pad_token_id + EOS = self.tokenizer.eos_token_id + BOS = self.tokenizer.bos_token_id + + def tokenize(t): + return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] + + def transformer_encode(t): + if self.clip_skip > 1: + rt = self.transformer(input_ids=t, output_hidden_states=True) + return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) + else: + return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state + + def split(x): + return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] + + def pad(x, p, i): + return x[:i] if len(x) >= i else x + [p] * (i - len(x)) + + raw_tokens_list = tokenize(text) + tokens_list = [] + + for raw_tokens in raw_tokens_list: + raw_tokens_123 = split(raw_tokens) + raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] + raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] + tokens_list.append(raw_tokens_123) + + tokens_list = torch.IntTensor(tokens_list).to(self.device) + + feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') + y = transformer_encode(feed) + z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) + + return z + + +# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py +def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + limit = k.shape[0] + att_step = 1 + q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) + k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) + v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) + + q_chunks.reverse() + k_chunks.reverse() + v_chunks.reverse() + sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + del k, q, v + for i in range(0, limit, att_step): + q_buffer = q_chunks.pop() + k_buffer = k_chunks.pop() + v_buffer = v_chunks.pop() + sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale + + del k_buffer, q_buffer + # attention, what we cannot get enough of, by chunks + + sim_buffer = sim_buffer.softmax(dim=-1) + + sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) + del v_buffer + sim[i:i + att_step, :, :] = sim_buffer + + del sim_buffer + sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) + return self.to_out(sim) diff --git a/custom-demo/back-end/model/anytext/cldm/model.py b/custom-demo/back-end/model/anytext/cldm/model.py new file mode 100644 index 0000000..688f2ed --- /dev/null +++ b/custom-demo/back-end/model/anytext/cldm/model.py @@ -0,0 +1,40 @@ +import os +import torch + +from omegaconf import OmegaConf +from iopaint.model.anytext.ldm.util import instantiate_from_config + + +def get_state_dict(d): + return d.get("state_dict", d) + + +def load_state_dict(ckpt_path, location="cpu"): + _, extension = os.path.splitext(ckpt_path) + if extension.lower() == ".safetensors": + import safetensors.torch + + state_dict = safetensors.torch.load_file(ckpt_path, device=location) + else: + state_dict = get_state_dict( + torch.load(ckpt_path, map_location=torch.device(location)) + ) + state_dict = get_state_dict(state_dict) + print(f"Loaded state_dict from [{ckpt_path}]") + return state_dict + + +def create_model(config_path, device, cond_stage_path=None, use_fp16=False): + config = OmegaConf.load(config_path) + # if cond_stage_path: + # config.model.params.cond_stage_config.params.version = ( + # cond_stage_path # use pre-downloaded ckpts, in case blocked + # ) + config.model.params.cond_stage_config.params.device = str(device) + if use_fp16: + config.model.params.use_fp16 = True + config.model.params.control_stage_config.params.use_fp16 = True + config.model.params.unet_config.params.use_fp16 = True + model = instantiate_from_config(config.model).cpu() + print(f"Loaded model config from [{config_path}]") + return model diff --git a/custom-demo/back-end/model/anytext/cldm/recognizer.py b/custom-demo/back-end/model/anytext/cldm/recognizer.py new file mode 100755 index 0000000..0621512 --- /dev/null +++ b/custom-demo/back-end/model/anytext/cldm/recognizer.py @@ -0,0 +1,300 @@ +""" +Copyright (c) Alibaba, Inc. and its affiliates. +""" +import os +import cv2 +import numpy as np +import math +import traceback +from easydict import EasyDict as edict +import time +from iopaint.model.anytext.ocr_recog.RecModel import RecModel +import torch +import torch.nn.functional as F + + +def min_bounding_rect(img): + ret, thresh = cv2.threshold(img, 127, 255, 0) + contours, hierarchy = cv2.findContours( + thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + if len(contours) == 0: + print("Bad contours, using fake bbox...") + return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) + max_contour = max(contours, key=cv2.contourArea) + rect = cv2.minAreaRect(max_contour) + box = cv2.boxPoints(rect) + box = np.int0(box) + # sort + x_sorted = sorted(box, key=lambda x: x[0]) + left = x_sorted[:2] + right = x_sorted[2:] + left = sorted(left, key=lambda x: x[1]) + (tl, bl) = left + right = sorted(right, key=lambda x: x[1]) + (tr, br) = right + if tl[1] > bl[1]: + (tl, bl) = (bl, tl) + if tr[1] > br[1]: + (tr, br) = (br, tr) + return np.array([tl, tr, br, bl]) + + +def create_predictor(model_dir=None, model_lang="ch", is_onnx=False): + model_file_path = model_dir + if model_file_path is not None and not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format(model_file_path)) + + if is_onnx: + import onnxruntime as ort + + sess = ort.InferenceSession( + model_file_path, providers=["CPUExecutionProvider"] + ) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' + return sess + else: + if model_lang == "ch": + n_class = 6625 + elif model_lang == "en": + n_class = 97 + else: + raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") + rec_config = edict( + in_channels=3, + backbone=edict( + type="MobileNetV1Enhance", + scale=0.5, + last_conv_stride=[1, 2], + last_pool_type="avg", + ), + neck=edict( + type="SequenceEncoder", + encoder_type="svtr", + dims=64, + depth=2, + hidden_dims=120, + use_guide=True, + ), + head=edict( + type="CTCHead", + fc_decay=0.00001, + out_channels=n_class, + return_feats=True, + ), + ) + + rec_model = RecModel(rec_config) + if model_file_path is not None: + rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu")) + rec_model.eval() + return rec_model.eval() + + +def _check_image_file(path): + img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"} + return any([path.lower().endswith(e) for e in img_end]) + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + if os.path.isfile(img_file) and _check_image_file(img_file): + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +class TextRecognizer(object): + def __init__(self, args, predictor): + self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] + self.rec_batch_num = args.rec_batch_num + self.predictor = predictor + self.chars = self.get_char_dict(args.rec_char_dict_path) + self.char2id = {x: i for i, x in enumerate(self.chars)} + self.is_onnx = not isinstance(self.predictor, torch.nn.Module) + self.use_fp16 = args.use_fp16 + + # img: CHW + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + assert imgC == img.shape[0] + imgW = int((imgH * max_wh_ratio)) + + h, w = img.shape[1:] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = torch.nn.functional.interpolate( + img.unsqueeze(0), + size=(imgH, resized_w), + mode="bilinear", + align_corners=True, + ) + resized_image /= 255.0 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) + padding_im[:, :, 0:resized_w] = resized_image[0] + return padding_im + + # img_list: list of tensors with shape chw 0-255 + def pred_imglist(self, img_list, show_debug=False, is_ori=False): + img_num = len(img_list) + assert img_num > 0 + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[2] / float(img.shape[1])) + # Sorting can speed up the recognition process + indices = torch.from_numpy(np.argsort(np.array(width_list))) + batch_num = self.rec_batch_num + preds_all = [None] * img_num + preds_neck_all = [None] * img_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + + imgC, imgH, imgW = self.rec_image_shape[:3] + max_wh_ratio = imgW / imgH + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[1:] + if h > w * 1.2: + img = img_list[indices[ino]] + img = torch.transpose(img, 1, 2).flip(dims=[1]) + img_list[indices[ino]] = img + h, w = img.shape[1:] + # wh_ratio = w * 1.0 / h + # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) + if self.use_fp16: + norm_img = norm_img.half() + norm_img = norm_img.unsqueeze(0) + norm_img_batch.append(norm_img) + norm_img_batch = torch.cat(norm_img_batch, dim=0) + if show_debug: + for i in range(len(norm_img_batch)): + _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() + _img = (_img + 0.5) * 255 + _img = _img[:, :, ::-1] + file_name = f"{indices[beg_img_no + i]}" + file_name = file_name + "_ori" if is_ori else file_name + cv2.imwrite(file_name + ".jpg", _img) + if self.is_onnx: + input_dict = {} + input_dict[self.predictor.get_inputs()[0].name] = ( + norm_img_batch.detach().cpu().numpy() + ) + outputs = self.predictor.run(None, input_dict) + preds = {} + preds["ctc"] = torch.from_numpy(outputs[0]) + preds["ctc_neck"] = [torch.zeros(1)] * img_num + else: + preds = self.predictor(norm_img_batch) + for rno in range(preds["ctc"].shape[0]): + preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] + preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] + + return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) + + def get_char_dict(self, character_dict_path): + character_str = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str.append(line) + dict_character = list(character_str) + dict_character = ["sos"] + dict_character + [" "] # eos is space + return dict_character + + def get_text(self, order): + char_list = [self.chars[text_id] for text_id in order] + return "".join(char_list) + + def decode(self, mat): + text_index = mat.detach().cpu().numpy().argmax(axis=1) + ignored_tokens = [0] + selection = np.ones(len(text_index), dtype=bool) + selection[1:] = text_index[1:] != text_index[:-1] + for ignored_token in ignored_tokens: + selection &= text_index != ignored_token + return text_index[selection], np.where(selection)[0] + + def get_ctcloss(self, preds, gt_text, weight): + if not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight).to(preds.device) + ctc_loss = torch.nn.CTCLoss(reduction="none") + log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC + targets = [] + target_lengths = [] + for t in gt_text: + targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] + target_lengths += [len(t)] + targets = torch.tensor(targets).to(preds.device) + target_lengths = torch.tensor(target_lengths).to(preds.device) + input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to( + preds.device + ) + loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) + loss = loss / input_lengths * weight + return loss + + +def main(): + rec_model_dir = "./ocr_weights/ppv3_rec.pth" + predictor = create_predictor(rec_model_dir) + args = edict() + args.rec_image_shape = "3, 48, 320" + args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt" + args.rec_batch_num = 6 + text_recognizer = TextRecognizer(args, predictor) + image_dir = "./test_imgs_cn" + gt_text = ["韩国小馆"] * 14 + + image_file_list = get_image_file_list(image_dir) + valid_image_file_list = [] + img_list = [] + + for image_file in image_file_list: + img = cv2.imread(image_file) + if img is None: + print("error in loading image:{}".format(image_file)) + continue + valid_image_file_list.append(image_file) + img_list.append(torch.from_numpy(img).permute(2, 0, 1).float()) + try: + tic = time.time() + times = [] + for i in range(10): + preds, _ = text_recognizer.pred_imglist(img_list) # get text + preds_all = preds.softmax(dim=2) + times += [(time.time() - tic) * 1000.0] + tic = time.time() + print(times) + print(np.mean(times[1:]) / len(preds_all)) + weight = np.ones(len(gt_text)) + loss = text_recognizer.get_ctcloss(preds, gt_text, weight) + for i in range(len(valid_image_file_list)): + pred = preds_all[i] + order, idx = text_recognizer.decode(pred) + text = text_recognizer.get_text(order) + print( + f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}' + ) + except Exception as E: + print(traceback.format_exc(), E) + + +if __name__ == "__main__": + main() diff --git a/custom-demo/back-end/model/anytext/ldm/__init__.py b/custom-demo/back-end/model/anytext/ldm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ldm/models/__init__.py b/custom-demo/back-end/model/anytext/ldm/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ldm/models/autoencoder.py b/custom-demo/back-end/model/anytext/ldm/models/autoencoder.py new file mode 100644 index 0000000..20d52e9 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/autoencoder.py @@ -0,0 +1,218 @@ +import torch +import torch.nn.functional as F +from contextlib import contextmanager + +from iopaint.model.anytext.ldm.modules.diffusionmodules.model import Encoder, Decoder +from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from iopaint.model.anytext.ldm.util import instantiate_from_config +from iopaint.model.anytext.ldm.modules.ema import LitEma + + +class AutoencoderKL(torch.nn.Module): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/__init__.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/ddim.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000..f8bbaff --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/diffusion/ddim.py @@ -0,0 +1,354 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + # cbs = len(ctmp[0]) + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img], "index": [10000]} + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + intermediates['index'].append(index) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + elif isinstance(c[k], dict): + c_in[k] = dict() + for key in c[k]: + if isinstance(c[k][key], list): + if not isinstance(c[k][key][0], torch.Tensor): + continue + c_in[k][key] = [torch.cat([ + unconditional_conditioning[k][key][i], + c[k][key][i]]) for i in range(len(c[k][key]))] + else: + c_in[k][key] = torch.cat([ + unconditional_conditioning[k][key], + c[k][key]]) + + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/ddpm.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000..9f48918 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/diffusion/ddpm.py @@ -0,0 +1,2380 @@ +""" +Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py +""" + +import torch +import torch.nn as nn +import numpy as np +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid +from omegaconf import ListConfig + +from iopaint.model.anytext.ldm.util import ( + log_txt_as_img, + exists, + default, + ismap, + isimage, + mean_flat, + count_params, + instantiate_from_config, +) +from iopaint.model.anytext.ldm.modules.ema import LitEma +from iopaint.model.anytext.ldm.modules.distributions.distributions import ( + normal_kl, + DiagonalGaussianDistribution, +) +from iopaint.model.anytext.ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import ( + make_beta_schedule, + extract_into_tensor, + noise_like, +) +from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler +import cv2 + + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + +PRINT_DEBUG = False + + +def print_grad(grad): + # print('Gradient:', grad) + # print(grad.shape) + a = grad.max() + b = grad.min() + # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}') + s = 255.0 / (a - b) + c = 255 * (-b / (a - b)) + grad = grad * s + c + # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}') + img = grad[0].permute(1, 2, 0).detach().cpu().numpy() + if img.shape[0] == 512: + cv2.imwrite("grad-img.jpg", img) + elif img.shape[0] == 64: + cv2.imwrite("grad-latent.jpg", img) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(torch.nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + "v", + ], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + print( + f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + ) + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if reset_ema: + assert exists(ckpt_path) + if ckpt_path is not None: + self.init_from_ckpt( + ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet + ) + if reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print( + " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ " + ) + assert self.use_ema + self.model_ema.reset_num_updates() + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.register_buffer("logvar", logvar) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + # np.save('1.npy', alphas_cumprod) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like( + self.betas**2 + / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + ) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len( + [ + name + for name, _ in itertools.chain( + self.named_parameters(), self.named_buffers() + ) + ] + ) + for name, param in tqdm( + itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params, + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[ + i % old_shape[0], j % old_shape[1] + ] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys:\n {missing}") + if len(unexpected) > 0: + print(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="Sampling t", + total=self.num_timesteps, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates, + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == "l1": + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == "l2": + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError( + f"Parameterization {self.parameterization} not yet supported" + ) + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = "train" if self.training else "val" + + loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f"{log_prefix}/loss": loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, "b h w c -> b c h w") + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + + self.log( + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict( + loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True + ) + self.log_dict( + loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True + ) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample( + batch_size=N, return_intermediates=True + ) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + force_null_conditioning=False, + *args, + **kwargs, + ): + self.force_null_conditioning = force_null_conditioning + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs["timesteps"] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = "concat" if concat_mode else "crossattn" + if ( + cond_stage_config == "__is_unconditional__" + and not self.force_null_conditioning + ): + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + reset_ema = kwargs.pop("reset_ema", False) + reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + if reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print( + " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ " + ) + assert self.use_ema + self.model_ema.reset_num_updates() + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if ( + self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + assert ( + self.scale_factor == 1.0 + ), "rather not use custom rescaling and std-rescaling simultaneously" + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer("scale_factor", 1.0 / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != "__is_first_stage__" + assert config != "__is_unconditional__" + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list( + self, samples, desc="", force_no_decoder_quantization=False + ): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append( + self.decode_first_stage( + zd.to(self.device), force_not_quantize=force_no_decoder_quantization + ) + ) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, "encode") and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min( + torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 + )[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip( + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], + ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip( + L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"], + ) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold( + self, x, kernel_size, stride, uf=1, df=1 + ): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting( + kernel_size[0], kernel_size[1], Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 + ) + + weighting = self.get_weighting( + kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h * uf, w * uf + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) + ) + + elif df > 1 and uf == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2 + ) + + weighting = self.get_weighting( + kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h // df, w // df + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) + ) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_x=False, + mask_k=None, + ): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if mask_k is not None: + mx = super().get_input(batch, mask_k) + if bs is not None: + mx = mx[:bs] + mx = mx.to(self.device) + encoder_posterior = self.encode_first_stage(mx) + mx = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None and not self.force_null_conditioning: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ["caption", "coordinates_bbox", "txt"]: + xc = batch[cond_key] + elif cond_key in ["class_label", "cls"]: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {"pos_x": pos_x, "pos_y": pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_x: + out.extend([x]) + if return_original_cond: + out.append(xc) + if mask_k: + out.append(mx) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.first_stage_model.decode(z) + + def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = ( + "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" + ) + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance + ).exp() * noise, logits.argmax(dim=1) + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc="Progressive Generation", + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample( + ddim_steps, batch_size, shape, cond, verbose=False, **kwargs + ) + + else: + samples, intermediates = self.sample( + cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs + ) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning( + batch_size, device=self.device + ) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device) + else: + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + return c + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=50, + ddim_eta=0.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key in ["class_label", "cls"]: + try: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + quantize_denoised=True, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask, + ) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1.0 - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask, + ) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising( + c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N, + ) + prog_row = self._get_denoise_row_from_list( + progressives, desc="Progressive Generation" + ) + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print("Diffusion model optimizing logvar") + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert "target" in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } + ] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class DiffusionWrapper(torch.nn.Module): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop( + "sequential_crossattn", False + ) + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, + "concat", + "crossattn", + "hybrid", + "adm", + "hybrid-adm", + "crossattn-adm", + ] + + def forward( + self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None + ): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == "concat": + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == "crossattn": + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == "hybrid": + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == "hybrid-adm": + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == "crossattn-adm": + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == "adm": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__( + self, + *args, + low_scale_config, + low_scale_key="LR", + noise_level_key=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, "b h w c -> b c h w") + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. when extracting a custom noise level for bsr + raise NotImplementedError("TODO") + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[ + f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}" + ] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key in ["class_label", "cls"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning( + N, unconditional_guidance_label + ) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising( + c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N, + ) + prog_row = self._get_denoise_row_from_list( + progressives, desc="Progressive Generation" + ) + log["progressive_row"] = prog_row + + return log + + +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__( + self, + concat_keys: tuple, + finetune_keys=( + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs, + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): + assert exists(ckpt_path), "can only finetune from a given checkpoint" + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" + ) + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), "did not find matching parameter to modify" + new_entry[:, : self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input( + batch, self.first_stage_key, bs=N, return_first_stage_outputs=True + ) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key in ["class_label", "cls"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage( + c_cat[:, self.c_concat_log_start : self.c_concat_log_end] + ) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning( + N, unconditional_guidance_label + ) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + + return log + + +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, + **kwargs, + ): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = ( + rearrange(batch[ck], "b h w c -> b c h w") + .to(memory_format=torch.contiguous_format) + .float() + ) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = ( + rearrange(args[0]["masked_image"], "b h w c -> b c h w") + .to(memory_format=torch.contiguous_format) + .float() + ) + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for depth2img" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin( + cc, dim=[1, 2, 3], keepdim=True + ), torch.amax(cc, dim=[1, 2, 3], keepdim=True) + cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0 + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin( + depth, dim=[1, 2, 3], keepdim=True + ), torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0 + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + + def __init__( + self, + concat_keys=("lr",), + reshuffle_patch_size=None, + low_scale_config=None, + low_scale_key=None, + *args, + **kwargs, + ): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for upscaling-ft" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, "b h w c -> b c h w") + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange( + cc, + "b c (p1 h) (p2 w) -> b (p1 p2 c) h w", + p1=self.reshuffle_patch_size, + p2=self.reshuffle_patch_size, + ) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} + else: + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w") + return log diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 0000000..7427f38 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 0000000..095e5ba --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1154 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 0000000..7d137b8 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,87 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/plms.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/plms.py new file mode 100644 index 0000000..5f35d55 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/diffusion/plms.py @@ -0,0 +1,244 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from iopaint.model.anytext.ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/custom-demo/back-end/model/anytext/ldm/models/diffusion/sampling_util.py b/custom-demo/back-end/model/anytext/ldm/models/diffusion/sampling_util.py new file mode 100644 index 0000000..7eff02b --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ldm/modules/__init__.py b/custom-demo/back-end/model/anytext/ldm/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ldm/modules/attention.py b/custom-demo/back-end/model/anytext/ldm/modules/attention.py new file mode 100644 index 0000000..df92aa7 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/attention.py @@ -0,0 +1,360 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint + + +# CrossAttn precision handling +import os + +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION == "fp32": + with torch.autocast(enabled=False, device_type="cuda"): + q, k = q.float(), k.float() + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + else: + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class SDPACrossAttention(CrossAttention): + def forward(self, x, context=None, mask=None): + batch_size, sequence_length, inner_dim = x.shape + + if mask is not None: + mask = self.prepare_attention_mask(mask, sequence_length, batch_size) + mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) + + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + + k_in = self.to_k(context) + v_in = self.to_v(context) + + head_dim = inner_dim // h + q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + + del q_in, k_in, v_in + + dtype = q.dtype + if _ATTN_PRECISION == "fp32": + q, k, v = q.float(), k.float(), v.float() + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, h * head_dim + ) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + ): + super().__init__() + + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + attn_cls = SDPACrossAttention + else: + attn_cls = CrossAttention + + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = ( + self.attn1( + self.norm1(x), context=context if self.disable_self_attn else None + ) + + x + ) + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + ): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/__init__.py b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/model.py b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000..3472824 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,973 @@ +# pytorch_diffusion + derived encoder decoder +import math + +import numpy as np +import torch +import torch.nn as nn + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class AttnBlock2_0(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + # output: [1, 512, 64, 64] + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + + # q = q.reshape(b, c, h * w).transpose() + # q = q.permute(0, 2, 1) # b,hw,c + # k = k.reshape(b, c, h * w) # b,c,hw + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + # (batch, num_heads, seq_len, head_dim) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.to(q.dtype) + + h_ = self.proj_out(hidden_states) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + assert attn_kwargs is None + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + # print(f"Using torch.nn.functional.scaled_dot_product_attention") + return AttnBlock2_0(in_channels) + return AttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x diff --git a/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/openaimodel.py b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000..fd3d6be --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,786 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer +from iopaint.model.anytext.ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + self.use_fp16 = use_fp16 + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/upscaling.py b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 0000000..5f92630 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from iopaint.model.anytext.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from iopaint.model.anytext.ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/util.py b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000..da29c72 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from iopaint.model.anytext.ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas.to(torch.float32), alphas.to(torch.float32), alphas_prev.astype(np.float32) + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + # return super().forward(x.float()).type(x.dtype) + return super().forward(x).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ldm/modules/distributions/__init__.py b/custom-demo/back-end/model/anytext/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ldm/modules/distributions/distributions.py b/custom-demo/back-end/model/anytext/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000..f2b8ef9 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/custom-demo/back-end/model/anytext/ldm/modules/ema.py b/custom-demo/back-end/model/anytext/ldm/modules/ema.py new file mode 100644 index 0000000..bded250 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/custom-demo/back-end/model/anytext/ldm/modules/encoders/__init__.py b/custom-demo/back-end/model/anytext/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ldm/modules/encoders/modules.py b/custom-demo/back-end/model/anytext/ldm/modules/encoders/modules.py new file mode 100644 index 0000000..ceac395 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/modules/encoders/modules.py @@ -0,0 +1,411 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import ( + T5Tokenizer, + T5EncoderModel, + CLIPTokenizer, + CLIPTextModel, + AutoProcessor, + CLIPVisionModelWithProjection, +) + +from iopaint.model.anytext.ldm.util import count_params + + +def _expand_mask(mask, dtype, tgt_len=None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +def _build_causal_attention_mask(bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + +class FrozenCLIPEmbedderT3(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + use_vision=False, + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + if use_vision: + self.vit = CLIPVisionModelWithProjection.from_pretrained(version) + self.processor = AutoProcessor.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def embedding_forward( + self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + embedding_manager=None, + ): + seq_length = ( + input_ids.shape[-1] + if input_ids is not None + else inputs_embeds.shape[-2] + ) + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + return embeddings + + self.transformer.text_model.embeddings.forward = embedding_forward.__get__( + self.transformer.text_model.embeddings + ) + + def encoder_forward( + self, + inputs_embeds, + attention_mask=None, + causal_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + return hidden_states + + self.transformer.text_model.encoder.forward = encoder_forward.__get__( + self.transformer.text_model.encoder + ) + + def text_encoder_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if input_ids is None: + raise ValueError("You have to specify either input_ids") + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding_manager=embedding_manager, + ) + bsz, seq_len = input_shape + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _build_causal_attention_mask( + bsz, seq_len, hidden_states.dtype + ).to(hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + return last_hidden_state + + self.transformer.text_model.forward = text_encoder_forward.__get__( + self.transformer.text_model + ) + + def transformer_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + embedding_manager=embedding_manager, + ) + + self.transformer.forward = transformer_forward.__get__(self.transformer) + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text, **kwargs): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + z = self.transformer(input_ids=tokens, **kwargs) + return z + + def encode(self, text, **kwargs): + return self(text, **kwargs) diff --git a/custom-demo/back-end/model/anytext/ldm/util.py b/custom-demo/back-end/model/anytext/ldm/util.py new file mode 100644 index 0000000..d456a86 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ldm/util.py @@ -0,0 +1,197 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size) + nc = int(32 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config, **kwargs): + if "target" not in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/main.py b/custom-demo/back-end/model/anytext/main.py new file mode 100644 index 0000000..f7b2d2e --- /dev/null +++ b/custom-demo/back-end/model/anytext/main.py @@ -0,0 +1,45 @@ +import cv2 +import os + +from anytext_pipeline import AnyTextPipeline +from utils import save_images + +seed = 66273235 +# seed_everything(seed) + +pipe = AnyTextPipeline( + ckpt_path="/Users/cwq/code/github/IOPaint/iopaint/model/anytext/anytext_v1.1_fp16.ckpt", + font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf", + use_fp16=False, + device="mps", +) + +img_save_folder = "SaveImages" +rgb_image = cv2.imread( + "/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg" +)[..., ::-1] + +masked_image = cv2.imread( + "/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png" +)[..., ::-1] + +rgb_image = cv2.resize(rgb_image, (512, 512)) +masked_image = cv2.resize(masked_image, (512, 512)) + +# results: list of rgb ndarray +results, rtn_code, rtn_warning = pipe( + prompt='A cake with colorful characters that reads "EVERYDAY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks', + negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture", + image=rgb_image, + masked_image=masked_image, + num_inference_steps=20, + strength=1.0, + guidance_scale=9.0, + height=rgb_image.shape[0], + width=rgb_image.shape[1], + seed=seed, + sort_priority="y", +) +if rtn_code >= 0: + save_images(results, img_save_folder) + print(f"Done, result images are saved in: {img_save_folder}") diff --git a/custom-demo/back-end/model/anytext/ocr_recog/RNN.py b/custom-demo/back-end/model/anytext/ocr_recog/RNN.py new file mode 100755 index 0000000..cf16855 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/RNN.py @@ -0,0 +1,210 @@ +from torch import nn +import torch +from .RecSVTR import Block + +class Swish(nn.Module): + def __int__(self): + super(Swish, self).__int__() + + def forward(self,x): + return x*torch.sigmoid(x) + +class Im2Im(nn.Module): + def __init__(self, in_channels, **kwargs): + super().__init__() + self.out_channels = in_channels + + def forward(self, x): + return x + +class Im2Seq(nn.Module): + def __init__(self, in_channels, **kwargs): + super().__init__() + self.out_channels = in_channels + + def forward(self, x): + B, C, H, W = x.shape + # assert H == 1 + x = x.reshape(B, C, H * W) + x = x.permute((0, 2, 1)) + return x + +class EncoderWithRNN(nn.Module): + def __init__(self, in_channels,**kwargs): + super(EncoderWithRNN, self).__init__() + hidden_size = kwargs.get('hidden_size', 256) + self.out_channels = hidden_size * 2 + self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True) + + def forward(self, x): + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + return x + +class SequenceEncoder(nn.Module): + def __init__(self, in_channels, encoder_type='rnn', **kwargs): + super(SequenceEncoder, self).__init__() + self.encoder_reshape = Im2Seq(in_channels) + self.out_channels = self.encoder_reshape.out_channels + self.encoder_type = encoder_type + if encoder_type == 'reshape': + self.only_reshape = True + else: + support_encoder_dict = { + 'reshape': Im2Seq, + 'rnn': EncoderWithRNN, + 'svtr': EncoderWithSVTR + } + assert encoder_type in support_encoder_dict, '{} must in {}'.format( + encoder_type, support_encoder_dict.keys()) + + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels,**kwargs) + self.out_channels = self.encoder.out_channels + self.only_reshape = False + + def forward(self, x): + if self.encoder_type != 'svtr': + x = self.encoder_reshape(x) + if not self.only_reshape: + x = self.encoder(x) + return x + else: + x = self.encoder(x) + x = self.encoder_reshape(x) + return x + +class ConvBNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias_attr=False, + groups=1, + act=nn.GELU): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), + bias=bias_attr) + self.norm = nn.BatchNorm2d(out_channels) + self.act = Swish() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class EncoderWithSVTR(nn.Module): + def __init__( + self, + in_channels, + dims=64, # XS + depth=2, + hidden_dims=120, + use_guide=False, + num_heads=8, + qkv_bias=True, + mlp_ratio=2.0, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path=0., + qk_scale=None): + super(EncoderWithSVTR, self).__init__() + self.depth = depth + self.use_guide = use_guide + self.conv1 = ConvBNLayer( + in_channels, in_channels // 8, padding=1, act='swish') + self.conv2 = ConvBNLayer( + in_channels // 8, hidden_dims, kernel_size=1, act='swish') + + self.svtr_block = nn.ModuleList([ + Block( + dim=hidden_dims, + num_heads=num_heads, + mixer='Global', + HW=None, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer='swish', + attn_drop=attn_drop_rate, + drop_path=drop_path, + norm_layer='nn.LayerNorm', + epsilon=1e-05, + prenorm=False) for i in range(depth) + ]) + self.norm = nn.LayerNorm(hidden_dims, eps=1e-6) + self.conv3 = ConvBNLayer( + hidden_dims, in_channels, kernel_size=1, act='swish') + # last conv-nxn, the input is concat of input tensor and conv3 output tensor + self.conv4 = ConvBNLayer( + 2 * in_channels, in_channels // 8, padding=1, act='swish') + + self.conv1x1 = ConvBNLayer( + in_channels // 8, dims, kernel_size=1, act='swish') + self.out_channels = dims + self.apply(self._init_weights) + + def _init_weights(self, m): + # weight initialization + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + # for use guide + if self.use_guide: + z = x.clone() + z.stop_gradient = True + else: + z = x + # for short cut + h = z + # reduce dim + z = self.conv1(z) + z = self.conv2(z) + # SVTR global block + B, C, H, W = z.shape + z = z.flatten(2).permute(0, 2, 1) + + for blk in self.svtr_block: + z = blk(z) + + z = self.norm(z) + # last stage + z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2) + z = self.conv3(z) + z = torch.cat((h, z), dim=1) + z = self.conv1x1(self.conv4(z)) + + return z + +if __name__=="__main__": + svtrRNN = EncoderWithSVTR(56) + print(svtrRNN) \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ocr_recog/RecCTCHead.py b/custom-demo/back-end/model/anytext/ocr_recog/RecCTCHead.py new file mode 100755 index 0000000..867ede9 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/RecCTCHead.py @@ -0,0 +1,48 @@ +from torch import nn + + +class CTCHead(nn.Module): + def __init__(self, + in_channels, + out_channels=6625, + fc_decay=0.0004, + mid_channels=None, + return_feats=False, + **kwargs): + super(CTCHead, self).__init__() + if mid_channels is None: + self.fc = nn.Linear( + in_channels, + out_channels, + bias=True,) + else: + self.fc1 = nn.Linear( + in_channels, + mid_channels, + bias=True, + ) + self.fc2 = nn.Linear( + mid_channels, + out_channels, + bias=True, + ) + + self.out_channels = out_channels + self.mid_channels = mid_channels + self.return_feats = return_feats + + def forward(self, x, labels=None): + if self.mid_channels is None: + predicts = self.fc(x) + else: + x = self.fc1(x) + predicts = self.fc2(x) + + if self.return_feats: + result = dict() + result['ctc'] = predicts + result['ctc_neck'] = x + else: + result = predicts + + return result diff --git a/custom-demo/back-end/model/anytext/ocr_recog/RecModel.py b/custom-demo/back-end/model/anytext/ocr_recog/RecModel.py new file mode 100755 index 0000000..c2313bf --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/RecModel.py @@ -0,0 +1,45 @@ +from torch import nn +from .RNN import SequenceEncoder, Im2Seq, Im2Im +from .RecMv1_enhance import MobileNetV1Enhance + +from .RecCTCHead import CTCHead + +backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance} +neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im} +head_dict = {'CTCHead':CTCHead} + + +class RecModel(nn.Module): + def __init__(self, config): + super().__init__() + assert 'in_channels' in config, 'in_channels must in model config' + backbone_type = config.backbone.pop('type') + assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}' + self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) + + neck_type = config.neck.pop('type') + assert neck_type in neck_dict, f'neck.type must in {neck_dict}' + self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) + + head_type = config.head.pop('type') + assert head_type in head_dict, f'head.type must in {head_dict}' + self.head = head_dict[head_type](self.neck.out_channels, **config.head) + + self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}' + + def load_3rd_state_dict(self, _3rd_name, _state): + self.backbone.load_3rd_state_dict(_3rd_name, _state) + self.neck.load_3rd_state_dict(_3rd_name, _state) + self.head.load_3rd_state_dict(_3rd_name, _state) + + def forward(self, x): + x = self.backbone(x) + x = self.neck(x) + x = self.head(x) + return x + + def encode(self, x): + x = self.backbone(x) + x = self.neck(x) + x = self.head.ctc_encoder(x) + return x diff --git a/custom-demo/back-end/model/anytext/ocr_recog/RecMv1_enhance.py b/custom-demo/back-end/model/anytext/ocr_recog/RecMv1_enhance.py new file mode 100644 index 0000000..7529b4a --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/RecMv1_enhance.py @@ -0,0 +1,232 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .common import Activation + + +class ConvBNLayer(nn.Module): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='hard_swish'): + super(ConvBNLayer, self).__init__() + self.act = act + self._conv = nn.Conv2d( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + bias=False) + + self._batch_norm = nn.BatchNorm2d( + num_filters, + ) + if self.act is not None: + self._act = Activation(act_type=act, inplace=True) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + if self.act is not None: + y = self._act(y) + return y + + +class DepthwiseSeparable(nn.Module): + def __init__(self, + num_channels, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + dw_size=3, + padding=1, + use_se=False): + super(DepthwiseSeparable, self).__init__() + self.use_se = use_se + self._depthwise_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=int(num_filters1 * scale), + filter_size=dw_size, + stride=stride, + padding=padding, + num_groups=int(num_groups * scale)) + if use_se: + self._se = SEModule(int(num_filters1 * scale)) + self._pointwise_conv = ConvBNLayer( + num_channels=int(num_filters1 * scale), + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0) + + def forward(self, inputs): + y = self._depthwise_conv(inputs) + if self.use_se: + y = self._se(y) + y = self._pointwise_conv(y) + return y + + +class MobileNetV1Enhance(nn.Module): + def __init__(self, + in_channels=3, + scale=0.5, + last_conv_stride=1, + last_pool_type='max', + **kwargs): + super().__init__() + self.scale = scale + self.block_list = [] + + self.conv1 = ConvBNLayer( + num_channels=in_channels, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1) + + conv2_1 = DepthwiseSeparable( + num_channels=int(32 * scale), + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale) + self.block_list.append(conv2_1) + + conv2_2 = DepthwiseSeparable( + num_channels=int(64 * scale), + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=1, + scale=scale) + self.block_list.append(conv2_2) + + conv3_1 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale) + self.block_list.append(conv3_1) + + conv3_2 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=(2, 1), + scale=scale) + self.block_list.append(conv3_2) + + conv4_1 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale) + self.block_list.append(conv4_1) + + conv4_2 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=(2, 1), + scale=scale) + self.block_list.append(conv4_2) + + for _ in range(5): + conv5 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + dw_size=5, + padding=2, + scale=scale, + use_se=False) + self.block_list.append(conv5) + + conv5_6 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=(2, 1), + dw_size=5, + padding=2, + scale=scale, + use_se=True) + self.block_list.append(conv5_6) + + conv6 = DepthwiseSeparable( + num_channels=int(1024 * scale), + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=last_conv_stride, + dw_size=5, + padding=2, + use_se=True, + scale=scale) + self.block_list.append(conv6) + + self.block_list = nn.Sequential(*self.block_list) + if last_pool_type == 'avg': + self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.out_channels = int(1024 * scale) + + def forward(self, inputs): + y = self.conv1(inputs) + y = self.block_list(y) + y = self.pool(y) + return y + +def hardsigmoid(x): + return F.relu6(x + 3., inplace=True) / 6. + +class SEModule(nn.Module): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.conv2 = nn.Conv2d( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = hardsigmoid(outputs) + x = torch.mul(inputs, outputs) + + return x diff --git a/custom-demo/back-end/model/anytext/ocr_recog/RecSVTR.py b/custom-demo/back-end/model/anytext/ocr_recog/RecSVTR.py new file mode 100644 index 0000000..484b3df --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/RecSVTR.py @@ -0,0 +1,591 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.nn.init import trunc_normal_, zeros_, ones_ +from torch.nn import functional + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. or not training: + return x + keep_prob = torch.tensor(1 - drop_prob) + shape = (x.size()[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype) + random_tensor = torch.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class Swish(nn.Module): + def __int__(self): + super(Swish, self).__int__() + + def forward(self,x): + return x*torch.sigmoid(x) + + +class ConvBNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias_attr=False, + groups=1, + act=nn.GELU): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), + bias=bias_attr) + self.norm = nn.BatchNorm2d(out_channels) + self.act = act() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Module): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + if isinstance(act_layer, str): + self.act = Swish() + else: + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMixer(nn.Module): + def __init__( + self, + dim, + num_heads=8, + HW=(8, 25), + local_k=(3, 3), ): + super().__init__() + self.HW = HW + self.dim = dim + self.local_mixer = nn.Conv2d( + dim, + dim, + local_k, + 1, (local_k[0] // 2, local_k[1] // 2), + groups=num_heads, + # weight_attr=ParamAttr(initializer=KaimingNormal()) + ) + + def forward(self, x): + h = self.HW[0] + w = self.HW[1] + x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w]) + x = self.local_mixer(x) + x = x.flatten(2).transpose([0, 2, 1]) + return x + + +class Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + mixer='Global', + HW=(8, 25), + local_k=(7, 11), + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.HW = HW + if HW is not None: + H = HW[0] + W = HW[1] + self.N = H * W + self.C = dim + if mixer == 'Local' and HW is not None: + hk = local_k[0] + wk = local_k[1] + mask = torch.ones([H * W, H + hk - 1, W + wk - 1]) + for h in range(0, H): + for w in range(0, W): + mask[h * W + w, h:h + hk, w:w + wk] = 0. + mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // + 2].flatten(1) + mask_inf = torch.full([H * W, H * W],fill_value=float('-inf')) + mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf) + self.mask = mask[None,None,:] + # self.mask = mask.unsqueeze([0, 1]) + self.mixer = mixer + + def forward(self, x): + if self.HW is not None: + N = self.N + C = self.C + else: + _, N, C = x.shape + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + + attn = (q.matmul(k.permute((0, 1, 3, 2)))) + if self.mixer == 'Local': + attn += self.mask + attn = functional.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, + dim, + num_heads, + mixer='Global', + local_mixer=(7, 11), + HW=(8, 25), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer='nn.LayerNorm', + epsilon=1e-6, + prenorm=True): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, eps=epsilon) + else: + self.norm1 = norm_layer(dim) + if mixer == 'Global' or mixer == 'Local': + + self.mixer = Attention( + dim, + num_heads=num_heads, + mixer=mixer, + HW=HW, + local_k=local_mixer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + elif mixer == 'Conv': + self.mixer = ConvMixer( + dim, num_heads=num_heads, HW=HW, local_k=local_mixer) + else: + raise TypeError("The mixer must be one of [Global, Local, Conv]") + + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, eps=epsilon) + else: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_ratio = mlp_ratio + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + self.prenorm = prenorm + + def forward(self, x): + if self.prenorm: + x = self.norm1(x + self.drop_path(self.mixer(x))) + x = self.norm2(x + self.drop_path(self.mlp(x))) + else: + x = x + self.drop_path(self.mixer(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, + img_size=(32, 100), + in_channels=3, + embed_dim=768, + sub_num=2): + super().__init__() + num_patches = (img_size[1] // (2 ** sub_num)) * \ + (img_size[0] // (2 ** sub_num)) + self.img_size = img_size + self.num_patches = num_patches + self.embed_dim = embed_dim + self.norm = None + if sub_num == 2: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False)) + if sub_num == 3: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 4, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False), + ConvBNLayer( + in_channels=embed_dim // 4, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False)) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).permute(0, 2, 1) + return x + + +class SubSample(nn.Module): + def __init__(self, + in_channels, + out_channels, + types='Pool', + stride=(2, 1), + sub_norm='nn.LayerNorm', + act=None): + super().__init__() + self.types = types + if types == 'Pool': + self.avgpool = nn.AvgPool2d( + kernel_size=(3, 5), stride=stride, padding=(1, 2)) + self.maxpool = nn.MaxPool2d( + kernel_size=(3, 5), stride=stride, padding=(1, 2)) + self.proj = nn.Linear(in_channels, out_channels) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + # weight_attr=ParamAttr(initializer=KaimingNormal()) + ) + self.norm = eval(sub_norm)(out_channels) + if act is not None: + self.act = act() + else: + self.act = None + + def forward(self, x): + + if self.types == 'Pool': + x1 = self.avgpool(x) + x2 = self.maxpool(x) + x = (x1 + x2) * 0.5 + out = self.proj(x.flatten(2).permute((0, 2, 1))) + else: + x = self.conv(x) + out = x.flatten(2).permute((0, 2, 1)) + out = self.norm(out) + if self.act is not None: + out = self.act(out) + + return out + + +class SVTRNet(nn.Module): + def __init__( + self, + img_size=[48, 100], + in_channels=3, + embed_dim=[64, 128, 256], + depth=[3, 6, 3], + num_heads=[2, 4, 8], + mixer=['Local'] * 6 + ['Global'] * + 6, # Local atten, Global atten, Conv + local_mixer=[[7, 11], [7, 11], [7, 11]], + patch_merging='Conv', # Conv, Pool, None + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + last_drop=0.1, + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer='nn.LayerNorm', + sub_norm='nn.LayerNorm', + epsilon=1e-6, + out_channels=192, + out_char_num=25, + block_unit='Block', + act='nn.GELU', + last_stage=True, + sub_num=2, + prenorm=True, + use_lenhead=False, + **kwargs): + super().__init__() + self.img_size = img_size + self.embed_dim = embed_dim + self.out_channels = out_channels + self.prenorm = prenorm + patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging + self.patch_embed = PatchEmbed( + img_size=img_size, + in_channels=in_channels, + embed_dim=embed_dim[0], + sub_num=sub_num) + num_patches = self.patch_embed.num_patches + self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)] + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0])) + # self.pos_embed = self.create_parameter( + # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_) + + # self.add_parameter("pos_embed", self.pos_embed) + + self.pos_drop = nn.Dropout(p=drop_rate) + Block_unit = eval(block_unit) + + dpr = np.linspace(0, drop_path_rate, sum(depth)) + self.blocks1 = nn.ModuleList( + [ + Block_unit( + dim=embed_dim[0], + num_heads=num_heads[0], + mixer=mixer[0:depth[0]][i], + HW=self.HW, + local_mixer=local_mixer[0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[0:depth[0]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[0]) + ] + ) + if patch_merging is not None: + self.sub_sample1 = SubSample( + embed_dim[0], + embed_dim[1], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 2, self.HW[1]] + else: + HW = self.HW + self.patch_merging = patch_merging + self.blocks2 = nn.ModuleList([ + Block_unit( + dim=embed_dim[1], + num_heads=num_heads[1], + mixer=mixer[depth[0]:depth[0] + depth[1]][i], + HW=HW, + local_mixer=local_mixer[1], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0]:depth[0] + depth[1]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[1]) + ]) + if patch_merging is not None: + self.sub_sample2 = SubSample( + embed_dim[1], + embed_dim[2], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 4, self.HW[1]] + else: + HW = self.HW + self.blocks3 = nn.ModuleList([ + Block_unit( + dim=embed_dim[2], + num_heads=num_heads[2], + mixer=mixer[depth[0] + depth[1]:][i], + HW=HW, + local_mixer=local_mixer[2], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1]:][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[2]) + ]) + self.last_stage = last_stage + if last_stage: + self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num)) + self.last_conv = nn.Conv2d( + in_channels=embed_dim[2], + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.hardswish = nn.Hardswish() + self.dropout = nn.Dropout(p=last_drop) + if not prenorm: + self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon) + self.use_lenhead = use_lenhead + if use_lenhead: + self.len_conv = nn.Linear(embed_dim[2], self.out_channels) + self.hardswish_len = nn.Hardswish() + self.dropout_len = nn.Dropout( + p=last_drop) + + trunc_normal_(self.pos_embed,std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight,std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks1: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample1( + x.permute([0, 2, 1]).reshape( + [-1, self.embed_dim[0], self.HW[0], self.HW[1]])) + for blk in self.blocks2: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample2( + x.permute([0, 2, 1]).reshape( + [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) + for blk in self.blocks3: + x = blk(x) + if not self.prenorm: + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + if self.use_lenhead: + len_x = self.len_conv(x.mean(1)) + len_x = self.dropout_len(self.hardswish_len(len_x)) + if self.last_stage: + if self.patch_merging is not None: + h = self.HW[0] // 4 + else: + h = self.HW[0] + x = self.avg_pool( + x.permute([0, 2, 1]).reshape( + [-1, self.embed_dim[2], h, self.HW[1]])) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + if self.use_lenhead: + return x, len_x + return x + + +if __name__=="__main__": + a = torch.rand(1,3,48,100) + svtr = SVTRNet() + + out = svtr(a) + print(svtr) + print(out.size()) \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ocr_recog/__init__.py b/custom-demo/back-end/model/anytext/ocr_recog/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/anytext/ocr_recog/common.py b/custom-demo/back-end/model/anytext/ocr_recog/common.py new file mode 100644 index 0000000..a328bb0 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/common.py @@ -0,0 +1,74 @@ + + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Hswish(nn.Module): + def __init__(self, inplace=True): + super(Hswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3., inplace=self.inplace) / 6. + +# out = max(0, min(1, slop*x+offset)) +# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None) +class Hsigmoid(nn.Module): + def __init__(self, inplace=True): + super(Hsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + # torch: F.relu6(x + 3., inplace=self.inplace) / 6. + # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. + return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. + +class GELU(nn.Module): + def __init__(self, inplace=True): + super(GELU, self).__init__() + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.gelu(x) + + +class Swish(nn.Module): + def __init__(self, inplace=True): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + if self.inplace: + x.mul_(torch.sigmoid(x)) + return x + else: + return x*torch.sigmoid(x) + + +class Activation(nn.Module): + def __init__(self, act_type, inplace=True): + super(Activation, self).__init__() + act_type = act_type.lower() + if act_type == 'relu': + self.act = nn.ReLU(inplace=inplace) + elif act_type == 'relu6': + self.act = nn.ReLU6(inplace=inplace) + elif act_type == 'sigmoid': + raise NotImplementedError + elif act_type == 'hard_sigmoid': + self.act = Hsigmoid(inplace) + elif act_type == 'hard_swish': + self.act = Hswish(inplace=inplace) + elif act_type == 'leakyrelu': + self.act = nn.LeakyReLU(inplace=inplace) + elif act_type == 'gelu': + self.act = GELU(inplace=inplace) + elif act_type == 'swish': + self.act = Swish(inplace=inplace) + else: + raise NotImplementedError + + def forward(self, inputs): + return self.act(inputs) \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/ocr_recog/en_dict.txt b/custom-demo/back-end/model/anytext/ocr_recog/en_dict.txt new file mode 100644 index 0000000..7677d31 --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/en_dict.txt @@ -0,0 +1,95 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +\ +] +^ +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +| +} +~ +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ + diff --git a/custom-demo/back-end/model/anytext/ocr_recog/ppocr_keys_v1.txt b/custom-demo/back-end/model/anytext/ocr_recog/ppocr_keys_v1.txt new file mode 100644 index 0000000..84b885d --- /dev/null +++ b/custom-demo/back-end/model/anytext/ocr_recog/ppocr_keys_v1.txt @@ -0,0 +1,6623 @@ +' +疗 +绚 +诚 +娇 +溜 +题 +贿 +者 +廖 +更 +纳 +加 +奉 +公 +一 +就 +汴 +计 +与 +路 +房 +原 +妇 +2 +0 +8 +- +7 +其 +> +: +] +, +, +骑 +刈 +全 +消 +昏 +傈 +安 +久 +钟 +嗅 +不 +影 +处 +驽 +蜿 +资 +关 +椤 +地 +瘸 +专 +问 +忖 +票 +嫉 +炎 +韵 +要 +月 +田 +节 +陂 +鄙 +捌 +备 +拳 +伺 +眼 +网 +盎 +大 +傍 +心 +东 +愉 +汇 +蹿 +科 +每 +业 +里 +航 +晏 +字 +平 +录 +先 +1 +3 +彤 +鲶 +产 +稍 +督 +腴 +有 +象 +岳 +注 +绍 +在 +泺 +文 +定 +核 +名 +水 +过 +理 +让 +偷 +率 +等 +这 +发 +” +为 +含 +肥 +酉 +相 +鄱 +七 +编 +猥 +锛 +日 +镀 +蒂 +掰 +倒 +辆 +栾 +栗 +综 +涩 +州 +雌 +滑 +馀 +了 +机 +块 +司 +宰 +甙 +兴 +矽 +抚 +保 +用 +沧 +秩 +如 +收 +息 +滥 +页 +疑 +埠 +! +! +姥 +异 +橹 +钇 +向 +下 +跄 +的 +椴 +沫 +国 +绥 +獠 +报 +开 +民 +蜇 +何 +分 +凇 +长 +讥 +藏 +掏 +施 +羽 +中 +讲 +派 +嘟 +人 +提 +浼 +间 +世 +而 +古 +多 +倪 +唇 +饯 +控 +庚 +首 +赛 +蜓 +味 +断 +制 +觉 +技 +替 +艰 +溢 +潮 +夕 +钺 +外 +摘 +枋 +动 +双 +单 +啮 +户 +枇 +确 +锦 +曜 +杜 +或 +能 +效 +霜 +盒 +然 +侗 +电 +晁 +放 +步 +鹃 +新 +杖 +蜂 +吒 +濂 +瞬 +评 +总 +隍 +对 +独 +合 +也 +是 +府 +青 +天 +诲 +墙 +组 +滴 +级 +邀 +帘 +示 +已 +时 +骸 +仄 +泅 +和 +遨 +店 +雇 +疫 +持 +巍 +踮 +境 +只 +亨 +目 +鉴 +崤 +闲 +体 +泄 +杂 +作 +般 +轰 +化 +解 +迂 +诿 +蛭 +璀 +腾 +告 +版 +服 +省 +师 +小 +规 +程 +线 +海 +办 +引 +二 +桧 +牌 +砺 +洄 +裴 +修 +图 +痫 +胡 +许 +犊 +事 +郛 +基 +柴 +呼 +食 +研 +奶 +律 +蛋 +因 +葆 +察 +戏 +褒 +戒 +再 +李 +骁 +工 +貂 +油 +鹅 +章 +啄 +休 +场 +给 +睡 +纷 +豆 +器 +捎 +说 +敏 +学 +会 +浒 +设 +诊 +格 +廓 +查 +来 +霓 +室 +溆 +¢ +诡 +寥 +焕 +舜 +柒 +狐 +回 +戟 +砾 +厄 +实 +翩 +尿 +五 +入 +径 +惭 +喹 +股 +宇 +篝 +| +; +美 +期 +云 +九 +祺 +扮 +靠 +锝 +槌 +系 +企 +酰 +阊 +暂 +蚕 +忻 +豁 +本 +羹 +执 +条 +钦 +H +獒 +限 +进 +季 +楦 +于 +芘 +玖 +铋 +茯 +未 +答 +粘 +括 +样 +精 +欠 +矢 +甥 +帷 +嵩 +扣 +令 +仔 +风 +皈 +行 +支 +部 +蓉 +刮 +站 +蜡 +救 +钊 +汗 +松 +嫌 +成 +可 +. +鹤 +院 +从 +交 +政 +怕 +活 +调 +球 +局 +验 +髌 +第 +韫 +谗 +串 +到 +圆 +年 +米 +/ +* +友 +忿 +检 +区 +看 +自 +敢 +刃 +个 +兹 +弄 +流 +留 +同 +没 +齿 +星 +聆 +轼 +湖 +什 +三 +建 +蛔 +儿 +椋 +汕 +震 +颧 +鲤 +跟 +力 +情 +璺 +铨 +陪 +务 +指 +族 +训 +滦 +鄣 +濮 +扒 +商 +箱 +十 +召 +慷 +辗 +所 +莞 +管 +护 +臭 +横 +硒 +嗓 +接 +侦 +六 +露 +党 +馋 +驾 +剖 +高 +侬 +妪 +幂 +猗 +绺 +骐 +央 +酐 +孝 +筝 +课 +徇 +缰 +门 +男 +西 +项 +句 +谙 +瞒 +秃 +篇 +教 +碲 +罚 +声 +呐 +景 +前 +富 +嘴 +鳌 +稀 +免 +朋 +啬 +睐 +去 +赈 +鱼 +住 +肩 +愕 +速 +旁 +波 +厅 +健 +茼 +厥 +鲟 +谅 +投 +攸 +炔 +数 +方 +击 +呋 +谈 +绩 +别 +愫 +僚 +躬 +鹧 +胪 +炳 +招 +喇 +膨 +泵 +蹦 +毛 +结 +5 +4 +谱 +识 +陕 +粽 +婚 +拟 +构 +且 +搜 +任 +潘 +比 +郢 +妨 +醪 +陀 +桔 +碘 +扎 +选 +哈 +骷 +楷 +亿 +明 +缆 +脯 +监 +睫 +逻 +婵 +共 +赴 +淝 +凡 +惦 +及 +达 +揖 +谩 +澹 +减 +焰 +蛹 +番 +祁 +柏 +员 +禄 +怡 +峤 +龙 +白 +叽 +生 +闯 +起 +细 +装 +谕 +竟 +聚 +钙 +上 +导 +渊 +按 +艾 +辘 +挡 +耒 +盹 +饪 +臀 +记 +邮 +蕙 +受 +各 +医 +搂 +普 +滇 +朗 +茸 +带 +翻 +酚 +( +光 +堤 +墟 +蔷 +万 +幻 +〓 +瑙 +辈 +昧 +盏 +亘 +蛀 +吉 +铰 +请 +子 +假 +闻 +税 +井 +诩 +哨 +嫂 +好 +面 +琐 +校 +馊 +鬣 +缂 +营 +访 +炖 +占 +农 +缀 +否 +经 +钚 +棵 +趟 +张 +亟 +吏 +茶 +谨 +捻 +论 +迸 +堂 +玉 +信 +吧 +瞠 +乡 +姬 +寺 +咬 +溏 +苄 +皿 +意 +赉 +宝 +尔 +钰 +艺 +特 +唳 +踉 +都 +荣 +倚 +登 +荐 +丧 +奇 +涵 +批 +炭 +近 +符 +傩 +感 +道 +着 +菊 +虹 +仲 +众 +懈 +濯 +颞 +眺 +南 +释 +北 +缝 +标 +既 +茗 +整 +撼 +迤 +贲 +挎 +耱 +拒 +某 +妍 +卫 +哇 +英 +矶 +藩 +治 +他 +元 +领 +膜 +遮 +穗 +蛾 +飞 +荒 +棺 +劫 +么 +市 +火 +温 +拈 +棚 +洼 +转 +果 +奕 +卸 +迪 +伸 +泳 +斗 +邡 +侄 +涨 +屯 +萋 +胭 +氡 +崮 +枞 +惧 +冒 +彩 +斜 +手 +豚 +随 +旭 +淑 +妞 +形 +菌 +吲 +沱 +争 +驯 +歹 +挟 +兆 +柱 +传 +至 +包 +内 +响 +临 +红 +功 +弩 +衡 +寂 +禁 +老 +棍 +耆 +渍 +织 +害 +氵 +渑 +布 +载 +靥 +嗬 +虽 +苹 +咨 +娄 +库 +雉 +榜 +帜 +嘲 +套 +瑚 +亲 +簸 +欧 +边 +6 +腿 +旮 +抛 +吹 +瞳 +得 +镓 +梗 +厨 +继 +漾 +愣 +憨 +士 +策 +窑 +抑 +躯 +襟 +脏 +参 +贸 +言 +干 +绸 +鳄 +穷 +藜 +音 +折 +详 +) +举 +悍 +甸 +癌 +黎 +谴 +死 +罩 +迁 +寒 +驷 +袖 +媒 +蒋 +掘 +模 +纠 +恣 +观 +祖 +蛆 +碍 +位 +稿 +主 +澧 +跌 +筏 +京 +锏 +帝 +贴 +证 +糠 +才 +黄 +鲸 +略 +炯 +饱 +四 +出 +园 +犀 +牧 +容 +汉 +杆 +浈 +汰 +瑷 +造 +虫 +瘩 +怪 +驴 +济 +应 +花 +沣 +谔 +夙 +旅 +价 +矿 +以 +考 +s +u +呦 +晒 +巡 +茅 +准 +肟 +瓴 +詹 +仟 +褂 +译 +桌 +混 +宁 +怦 +郑 +抿 +些 +余 +鄂 +饴 +攒 +珑 +群 +阖 +岔 +琨 +藓 +预 +环 +洮 +岌 +宀 +杲 +瀵 +最 +常 +囡 +周 +踊 +女 +鼓 +袭 +喉 +简 +范 +薯 +遐 +疏 +粱 +黜 +禧 +法 +箔 +斤 +遥 +汝 +奥 +直 +贞 +撑 +置 +绱 +集 +她 +馅 +逗 +钧 +橱 +魉 +[ +恙 +躁 +唤 +9 +旺 +膘 +待 +脾 +惫 +购 +吗 +依 +盲 +度 +瘿 +蠖 +俾 +之 +镗 +拇 +鲵 +厝 +簧 +续 +款 +展 +啃 +表 +剔 +品 +钻 +腭 +损 +清 +锶 +统 +涌 +寸 +滨 +贪 +链 +吠 +冈 +伎 +迥 +咏 +吁 +览 +防 +迅 +失 +汾 +阔 +逵 +绀 +蔑 +列 +川 +凭 +努 +熨 +揪 +利 +俱 +绉 +抢 +鸨 +我 +即 +责 +膦 +易 +毓 +鹊 +刹 +玷 +岿 +空 +嘞 +绊 +排 +术 +估 +锷 +违 +们 +苟 +铜 +播 +肘 +件 +烫 +审 +鲂 +广 +像 +铌 +惰 +铟 +巳 +胍 +鲍 +康 +憧 +色 +恢 +想 +拷 +尤 +疳 +知 +S +Y +F +D +A +峄 +裕 +帮 +握 +搔 +氐 +氘 +难 +墒 +沮 +雨 +叁 +缥 +悴 +藐 +湫 +娟 +苑 +稠 +颛 +簇 +后 +阕 +闭 +蕤 +缚 +怎 +佞 +码 +嘤 +蔡 +痊 +舱 +螯 +帕 +赫 +昵 +升 +烬 +岫 +、 +疵 +蜻 +髁 +蕨 +隶 +烛 +械 +丑 +盂 +梁 +强 +鲛 +由 +拘 +揉 +劭 +龟 +撤 +钩 +呕 +孛 +费 +妻 +漂 +求 +阑 +崖 +秤 +甘 +通 +深 +补 +赃 +坎 +床 +啪 +承 +吼 +量 +暇 +钼 +烨 +阂 +擎 +脱 +逮 +称 +P +神 +属 +矗 +华 +届 +狍 +葑 +汹 +育 +患 +窒 +蛰 +佼 +静 +槎 +运 +鳗 +庆 +逝 +曼 +疱 +克 +代 +官 +此 +麸 +耧 +蚌 +晟 +例 +础 +榛 +副 +测 +唰 +缢 +迹 +灬 +霁 +身 +岁 +赭 +扛 +又 +菡 +乜 +雾 +板 +读 +陷 +徉 +贯 +郁 +虑 +变 +钓 +菜 +圾 +现 +琢 +式 +乐 +维 +渔 +浜 +左 +吾 +脑 +钡 +警 +T +啵 +拴 +偌 +漱 +湿 +硕 +止 +骼 +魄 +积 +燥 +联 +踢 +玛 +则 +窿 +见 +振 +畿 +送 +班 +钽 +您 +赵 +刨 +印 +讨 +踝 +籍 +谡 +舌 +崧 +汽 +蔽 +沪 +酥 +绒 +怖 +财 +帖 +肱 +私 +莎 +勋 +羔 +霸 +励 +哼 +帐 +将 +帅 +渠 +纪 +婴 +娩 +岭 +厘 +滕 +吻 +伤 +坝 +冠 +戊 +隆 +瘁 +介 +涧 +物 +黍 +并 +姗 +奢 +蹑 +掣 +垸 +锴 +命 +箍 +捉 +病 +辖 +琰 +眭 +迩 +艘 +绌 +繁 +寅 +若 +毋 +思 +诉 +类 +诈 +燮 +轲 +酮 +狂 +重 +反 +职 +筱 +县 +委 +磕 +绣 +奖 +晋 +濉 +志 +徽 +肠 +呈 +獐 +坻 +口 +片 +碰 +几 +村 +柿 +劳 +料 +获 +亩 +惕 +晕 +厌 +号 +罢 +池 +正 +鏖 +煨 +家 +棕 +复 +尝 +懋 +蜥 +锅 +岛 +扰 +队 +坠 +瘾 +钬 +@ +卧 +疣 +镇 +譬 +冰 +彷 +频 +黯 +据 +垄 +采 +八 +缪 +瘫 +型 +熹 +砰 +楠 +襁 +箐 +但 +嘶 +绳 +啤 +拍 +盥 +穆 +傲 +洗 +盯 +塘 +怔 +筛 +丿 +台 +恒 +喂 +葛 +永 +¥ +烟 +酒 +桦 +书 +砂 +蚝 +缉 +态 +瀚 +袄 +圳 +轻 +蛛 +超 +榧 +遛 +姒 +奘 +铮 +右 +荽 +望 +偻 +卡 +丶 +氰 +附 +做 +革 +索 +戚 +坨 +桷 +唁 +垅 +榻 +岐 +偎 +坛 +莨 +山 +殊 +微 +骇 +陈 +爨 +推 +嗝 +驹 +澡 +藁 +呤 +卤 +嘻 +糅 +逛 +侵 +郓 +酌 +德 +摇 +※ +鬃 +被 +慨 +殡 +羸 +昌 +泡 +戛 +鞋 +河 +宪 +沿 +玲 +鲨 +翅 +哽 +源 +铅 +语 +照 +邯 +址 +荃 +佬 +顺 +鸳 +町 +霭 +睾 +瓢 +夸 +椁 +晓 +酿 +痈 +咔 +侏 +券 +噎 +湍 +签 +嚷 +离 +午 +尚 +社 +锤 +背 +孟 +使 +浪 +缦 +潍 +鞅 +军 +姹 +驶 +笑 +鳟 +鲁 +》 +孽 +钜 +绿 +洱 +礴 +焯 +椰 +颖 +囔 +乌 +孔 +巴 +互 +性 +椽 +哞 +聘 +昨 +早 +暮 +胶 +炀 +隧 +低 +彗 +昝 +铁 +呓 +氽 +藉 +喔 +癖 +瑗 +姨 +权 +胱 +韦 +堑 +蜜 +酋 +楝 +砝 +毁 +靓 +歙 +锲 +究 +屋 +喳 +骨 +辨 +碑 +武 +鸠 +宫 +辜 +烊 +适 +坡 +殃 +培 +佩 +供 +走 +蜈 +迟 +翼 +况 +姣 +凛 +浔 +吃 +飘 +债 +犟 +金 +促 +苛 +崇 +坂 +莳 +畔 +绂 +兵 +蠕 +斋 +根 +砍 +亢 +欢 +恬 +崔 +剁 +餐 +榫 +快 +扶 +‖ +濒 +缠 +鳜 +当 +彭 +驭 +浦 +篮 +昀 +锆 +秸 +钳 +弋 +娣 +瞑 +夷 +龛 +苫 +拱 +致 +% +嵊 +障 +隐 +弑 +初 +娓 +抉 +汩 +累 +蓖 +" +唬 +助 +苓 +昙 +押 +毙 +破 +城 +郧 +逢 +嚏 +獭 +瞻 +溱 +婿 +赊 +跨 +恼 +璧 +萃 +姻 +貉 +灵 +炉 +密 +氛 +陶 +砸 +谬 +衔 +点 +琛 +沛 +枳 +层 +岱 +诺 +脍 +榈 +埂 +征 +冷 +裁 +打 +蹴 +素 +瘘 +逞 +蛐 +聊 +激 +腱 +萘 +踵 +飒 +蓟 +吆 +取 +咙 +簋 +涓 +矩 +曝 +挺 +揣 +座 +你 +史 +舵 +焱 +尘 +苏 +笈 +脚 +溉 +榨 +诵 +樊 +邓 +焊 +义 +庶 +儋 +蟋 +蒲 +赦 +呷 +杞 +诠 +豪 +还 +试 +颓 +茉 +太 +除 +紫 +逃 +痴 +草 +充 +鳕 +珉 +祗 +墨 +渭 +烩 +蘸 +慕 +璇 +镶 +穴 +嵘 +恶 +骂 +险 +绋 +幕 +碉 +肺 +戳 +刘 +潞 +秣 +纾 +潜 +銮 +洛 +须 +罘 +销 +瘪 +汞 +兮 +屉 +r +林 +厕 +质 +探 +划 +狸 +殚 +善 +煊 +烹 +〒 +锈 +逯 +宸 +辍 +泱 +柚 +袍 +远 +蹋 +嶙 +绝 +峥 +娥 +缍 +雀 +徵 +认 +镱 +谷 += +贩 +勉 +撩 +鄯 +斐 +洋 +非 +祚 +泾 +诒 +饿 +撬 +威 +晷 +搭 +芍 +锥 +笺 +蓦 +候 +琊 +档 +礁 +沼 +卵 +荠 +忑 +朝 +凹 +瑞 +头 +仪 +弧 +孵 +畏 +铆 +突 +衲 +车 +浩 +气 +茂 +悖 +厢 +枕 +酝 +戴 +湾 +邹 +飚 +攘 +锂 +写 +宵 +翁 +岷 +无 +喜 +丈 +挑 +嗟 +绛 +殉 +议 +槽 +具 +醇 +淞 +笃 +郴 +阅 +饼 +底 +壕 +砚 +弈 +询 +缕 +庹 +翟 +零 +筷 +暨 +舟 +闺 +甯 +撞 +麂 +茌 +蔼 +很 +珲 +捕 +棠 +角 +阉 +媛 +娲 +诽 +剿 +尉 +爵 +睬 +韩 +诰 +匣 +危 +糍 +镯 +立 +浏 +阳 +少 +盆 +舔 +擘 +匪 +申 +尬 +铣 +旯 +抖 +赘 +瓯 +居 +ˇ +哮 +游 +锭 +茏 +歌 +坏 +甚 +秒 +舞 +沙 +仗 +劲 +潺 +阿 +燧 +郭 +嗖 +霏 +忠 +材 +奂 +耐 +跺 +砀 +输 +岖 +媳 +氟 +极 +摆 +灿 +今 +扔 +腻 +枝 +奎 +药 +熄 +吨 +话 +q +额 +慑 +嘌 +协 +喀 +壳 +埭 +视 +著 +於 +愧 +陲 +翌 +峁 +颅 +佛 +腹 +聋 +侯 +咎 +叟 +秀 +颇 +存 +较 +罪 +哄 +岗 +扫 +栏 +钾 +羌 +己 +璨 +枭 +霉 +煌 +涸 +衿 +键 +镝 +益 +岢 +奏 +连 +夯 +睿 +冥 +均 +糖 +狞 +蹊 +稻 +爸 +刿 +胥 +煜 +丽 +肿 +璃 +掸 +跚 +灾 +垂 +樾 +濑 +乎 +莲 +窄 +犹 +撮 +战 +馄 +软 +络 +显 +鸢 +胸 +宾 +妲 +恕 +埔 +蝌 +份 +遇 +巧 +瞟 +粒 +恰 +剥 +桡 +博 +讯 +凯 +堇 +阶 +滤 +卖 +斌 +骚 +彬 +兑 +磺 +樱 +舷 +两 +娱 +福 +仃 +差 +找 +桁 +÷ +净 +把 +阴 +污 +戬 +雷 +碓 +蕲 +楚 +罡 +焖 +抽 +妫 +咒 +仑 +闱 +尽 +邑 +菁 +爱 +贷 +沥 +鞑 +牡 +嗉 +崴 +骤 +塌 +嗦 +订 +拮 +滓 +捡 +锻 +次 +坪 +杩 +臃 +箬 +融 +珂 +鹗 +宗 +枚 +降 +鸬 +妯 +阄 +堰 +盐 +毅 +必 +杨 +崃 +俺 +甬 +状 +莘 +货 +耸 +菱 +腼 +铸 +唏 +痤 +孚 +澳 +懒 +溅 +翘 +疙 +杷 +淼 +缙 +骰 +喊 +悉 +砻 +坷 +艇 +赁 +界 +谤 +纣 +宴 +晃 +茹 +归 +饭 +梢 +铡 +街 +抄 +肼 +鬟 +苯 +颂 +撷 +戈 +炒 +咆 +茭 +瘙 +负 +仰 +客 +琉 +铢 +封 +卑 +珥 +椿 +镧 +窨 +鬲 +寿 +御 +袤 +铃 +萎 +砖 +餮 +脒 +裳 +肪 +孕 +嫣 +馗 +嵇 +恳 +氯 +江 +石 +褶 +冢 +祸 +阻 +狈 +羞 +银 +靳 +透 +咳 +叼 +敷 +芷 +啥 +它 +瓤 +兰 +痘 +懊 +逑 +肌 +往 +捺 +坊 +甩 +呻 +〃 +沦 +忘 +膻 +祟 +菅 +剧 +崆 +智 +坯 +臧 +霍 +墅 +攻 +眯 +倘 +拢 +骠 +铐 +庭 +岙 +瓠 +′ +缺 +泥 +迢 +捶 +? +? +郏 +喙 +掷 +沌 +纯 +秘 +种 +听 +绘 +固 +螨 +团 +香 +盗 +妒 +埚 +蓝 +拖 +旱 +荞 +铀 +血 +遏 +汲 +辰 +叩 +拽 +幅 +硬 +惶 +桀 +漠 +措 +泼 +唑 +齐 +肾 +念 +酱 +虚 +屁 +耶 +旗 +砦 +闵 +婉 +馆 +拭 +绅 +韧 +忏 +窝 +醋 +葺 +顾 +辞 +倜 +堆 +辋 +逆 +玟 +贱 +疾 +董 +惘 +倌 +锕 +淘 +嘀 +莽 +俭 +笏 +绑 +鲷 +杈 +择 +蟀 +粥 +嗯 +驰 +逾 +案 +谪 +褓 +胫 +哩 +昕 +颚 +鲢 +绠 +躺 +鹄 +崂 +儒 +俨 +丝 +尕 +泌 +啊 +萸 +彰 +幺 +吟 +骄 +苣 +弦 +脊 +瑰 +〈 +诛 +镁 +析 +闪 +剪 +侧 +哟 +框 +螃 +守 +嬗 +燕 +狭 +铈 +缮 +概 +迳 +痧 +鲲 +俯 +售 +笼 +痣 +扉 +挖 +满 +咋 +援 +邱 +扇 +歪 +便 +玑 +绦 +峡 +蛇 +叨 +〖 +泽 +胃 +斓 +喋 +怂 +坟 +猪 +该 +蚬 +炕 +弥 +赞 +棣 +晔 +娠 +挲 +狡 +创 +疖 +铕 +镭 +稷 +挫 +弭 +啾 +翔 +粉 +履 +苘 +哦 +楼 +秕 +铂 +土 +锣 +瘟 +挣 +栉 +习 +享 +桢 +袅 +磨 +桂 +谦 +延 +坚 +蔚 +噗 +署 +谟 +猬 +钎 +恐 +嬉 +雒 +倦 +衅 +亏 +璩 +睹 +刻 +殿 +王 +算 +雕 +麻 +丘 +柯 +骆 +丸 +塍 +谚 +添 +鲈 +垓 +桎 +蚯 +芥 +予 +飕 +镦 +谌 +窗 +醚 +菀 +亮 +搪 +莺 +蒿 +羁 +足 +J +真 +轶 +悬 +衷 +靛 +翊 +掩 +哒 +炅 +掐 +冼 +妮 +l +谐 +稚 +荆 +擒 +犯 +陵 +虏 +浓 +崽 +刍 +陌 +傻 +孜 +千 +靖 +演 +矜 +钕 +煽 +杰 +酗 +渗 +伞 +栋 +俗 +泫 +戍 +罕 +沾 +疽 +灏 +煦 +芬 +磴 +叱 +阱 +榉 +湃 +蜀 +叉 +醒 +彪 +租 +郡 +篷 +屎 +良 +垢 +隗 +弱 +陨 +峪 +砷 +掴 +颁 +胎 +雯 +绵 +贬 +沐 +撵 +隘 +篙 +暖 +曹 +陡 +栓 +填 +臼 +彦 +瓶 +琪 +潼 +哪 +鸡 +摩 +啦 +俟 +锋 +域 +耻 +蔫 +疯 +纹 +撇 +毒 +绶 +痛 +酯 +忍 +爪 +赳 +歆 +嘹 +辕 +烈 +册 +朴 +钱 +吮 +毯 +癜 +娃 +谀 +邵 +厮 +炽 +璞 +邃 +丐 +追 +词 +瓒 +忆 +轧 +芫 +谯 +喷 +弟 +半 +冕 +裙 +掖 +墉 +绮 +寝 +苔 +势 +顷 +褥 +切 +衮 +君 +佳 +嫒 +蚩 +霞 +佚 +洙 +逊 +镖 +暹 +唛 +& +殒 +顶 +碗 +獗 +轭 +铺 +蛊 +废 +恹 +汨 +崩 +珍 +那 +杵 +曲 +纺 +夏 +薰 +傀 +闳 +淬 +姘 +舀 +拧 +卷 +楂 +恍 +讪 +厩 +寮 +篪 +赓 +乘 +灭 +盅 +鞣 +沟 +慎 +挂 +饺 +鼾 +杳 +树 +缨 +丛 +絮 +娌 +臻 +嗳 +篡 +侩 +述 +衰 +矛 +圈 +蚜 +匕 +筹 +匿 +濞 +晨 +叶 +骋 +郝 +挚 +蚴 +滞 +增 +侍 +描 +瓣 +吖 +嫦 +蟒 +匾 +圣 +赌 +毡 +癞 +恺 +百 +曳 +需 +篓 +肮 +庖 +帏 +卿 +驿 +遗 +蹬 +鬓 +骡 +歉 +芎 +胳 +屐 +禽 +烦 +晌 +寄 +媾 +狄 +翡 +苒 +船 +廉 +终 +痞 +殇 +々 +畦 +饶 +改 +拆 +悻 +萄 +£ +瓿 +乃 +訾 +桅 +匮 +溧 +拥 +纱 +铍 +骗 +蕃 +龋 +缬 +父 +佐 +疚 +栎 +醍 +掳 +蓄 +x +惆 +颜 +鲆 +榆 +〔 +猎 +敌 +暴 +谥 +鲫 +贾 +罗 +玻 +缄 +扦 +芪 +癣 +落 +徒 +臾 +恿 +猩 +托 +邴 +肄 +牵 +春 +陛 +耀 +刊 +拓 +蓓 +邳 +堕 +寇 +枉 +淌 +啡 +湄 +兽 +酷 +萼 +碚 +濠 +萤 +夹 +旬 +戮 +梭 +琥 +椭 +昔 +勺 +蜊 +绐 +晚 +孺 +僵 +宣 +摄 +冽 +旨 +萌 +忙 +蚤 +眉 +噼 +蟑 +付 +契 +瓜 +悼 +颡 +壁 +曾 +窕 +颢 +澎 +仿 +俑 +浑 +嵌 +浣 +乍 +碌 +褪 +乱 +蔟 +隙 +玩 +剐 +葫 +箫 +纲 +围 +伐 +决 +伙 +漩 +瑟 +刑 +肓 +镳 +缓 +蹭 +氨 +皓 +典 +畲 +坍 +铑 +檐 +塑 +洞 +倬 +储 +胴 +淳 +戾 +吐 +灼 +惺 +妙 +毕 +珐 +缈 +虱 +盖 +羰 +鸿 +磅 +谓 +髅 +娴 +苴 +唷 +蚣 +霹 +抨 +贤 +唠 +犬 +誓 +逍 +庠 +逼 +麓 +籼 +釉 +呜 +碧 +秧 +氩 +摔 +霄 +穸 +纨 +辟 +妈 +映 +完 +牛 +缴 +嗷 +炊 +恩 +荔 +茆 +掉 +紊 +慌 +莓 +羟 +阙 +萁 +磐 +另 +蕹 +辱 +鳐 +湮 +吡 +吩 +唐 +睦 +垠 +舒 +圜 +冗 +瞿 +溺 +芾 +囱 +匠 +僳 +汐 +菩 +饬 +漓 +黑 +霰 +浸 +濡 +窥 +毂 +蒡 +兢 +驻 +鹉 +芮 +诙 +迫 +雳 +厂 +忐 +臆 +猴 +鸣 +蚪 +栈 +箕 +羡 +渐 +莆 +捍 +眈 +哓 +趴 +蹼 +埕 +嚣 +骛 +宏 +淄 +斑 +噜 +严 +瑛 +垃 +椎 +诱 +压 +庾 +绞 +焘 +廿 +抡 +迄 +棘 +夫 +纬 +锹 +眨 +瞌 +侠 +脐 +竞 +瀑 +孳 +骧 +遁 +姜 +颦 +荪 +滚 +萦 +伪 +逸 +粳 +爬 +锁 +矣 +役 +趣 +洒 +颔 +诏 +逐 +奸 +甭 +惠 +攀 +蹄 +泛 +尼 +拼 +阮 +鹰 +亚 +颈 +惑 +勒 +〉 +际 +肛 +爷 +刚 +钨 +丰 +养 +冶 +鲽 +辉 +蔻 +画 +覆 +皴 +妊 +麦 +返 +醉 +皂 +擀 +〗 +酶 +凑 +粹 +悟 +诀 +硖 +港 +卜 +z +杀 +涕 +± +舍 +铠 +抵 +弛 +段 +敝 +镐 +奠 +拂 +轴 +跛 +袱 +e +t +沉 +菇 +俎 +薪 +峦 +秭 +蟹 +历 +盟 +菠 +寡 +液 +肢 +喻 +染 +裱 +悱 +抱 +氙 +赤 +捅 +猛 +跑 +氮 +谣 +仁 +尺 +辊 +窍 +烙 +衍 +架 +擦 +倏 +璐 +瑁 +币 +楞 +胖 +夔 +趸 +邛 +惴 +饕 +虔 +蝎 +§ +哉 +贝 +宽 +辫 +炮 +扩 +饲 +籽 +魏 +菟 +锰 +伍 +猝 +末 +琳 +哚 +蛎 +邂 +呀 +姿 +鄞 +却 +歧 +仙 +恸 +椐 +森 +牒 +寤 +袒 +婆 +虢 +雅 +钉 +朵 +贼 +欲 +苞 +寰 +故 +龚 +坭 +嘘 +咫 +礼 +硷 +兀 +睢 +汶 +’ +铲 +烧 +绕 +诃 +浃 +钿 +哺 +柜 +讼 +颊 +璁 +腔 +洽 +咐 +脲 +簌 +筠 +镣 +玮 +鞠 +谁 +兼 +姆 +挥 +梯 +蝴 +谘 +漕 +刷 +躏 +宦 +弼 +b +垌 +劈 +麟 +莉 +揭 +笙 +渎 +仕 +嗤 +仓 +配 +怏 +抬 +错 +泯 +镊 +孰 +猿 +邪 +仍 +秋 +鼬 +壹 +歇 +吵 +炼 +< +尧 +射 +柬 +廷 +胧 +霾 +凳 +隋 +肚 +浮 +梦 +祥 +株 +堵 +退 +L +鹫 +跎 +凶 +毽 +荟 +炫 +栩 +玳 +甜 +沂 +鹿 +顽 +伯 +爹 +赔 +蛴 +徐 +匡 +欣 +狰 +缸 +雹 +蟆 +疤 +默 +沤 +啜 +痂 +衣 +禅 +w +i +h +辽 +葳 +黝 +钗 +停 +沽 +棒 +馨 +颌 +肉 +吴 +硫 +悯 +劾 +娈 +马 +啧 +吊 +悌 +镑 +峭 +帆 +瀣 +涉 +咸 +疸 +滋 +泣 +翦 +拙 +癸 +钥 +蜒 ++ +尾 +庄 +凝 +泉 +婢 +渴 +谊 +乞 +陆 +锉 +糊 +鸦 +淮 +I +B +N +晦 +弗 +乔 +庥 +葡 +尻 +席 +橡 +傣 +渣 +拿 +惩 +麋 +斛 +缃 +矮 +蛏 +岘 +鸽 +姐 +膏 +催 +奔 +镒 +喱 +蠡 +摧 +钯 +胤 +柠 +拐 +璋 +鸥 +卢 +荡 +倾 +^ +_ +珀 +逄 +萧 +塾 +掇 +贮 +笆 +聂 +圃 +冲 +嵬 +M +滔 +笕 +值 +炙 +偶 +蜱 +搐 +梆 +汪 +蔬 +腑 +鸯 +蹇 +敞 +绯 +仨 +祯 +谆 +梧 +糗 +鑫 +啸 +豺 +囹 +猾 +巢 +柄 +瀛 +筑 +踌 +沭 +暗 +苁 +鱿 +蹉 +脂 +蘖 +牢 +热 +木 +吸 +溃 +宠 +序 +泞 +偿 +拜 +檩 +厚 +朐 +毗 +螳 +吞 +媚 +朽 +担 +蝗 +橘 +畴 +祈 +糟 +盱 +隼 +郜 +惜 +珠 +裨 +铵 +焙 +琚 +唯 +咚 +噪 +骊 +丫 +滢 +勤 +棉 +呸 +咣 +淀 +隔 +蕾 +窈 +饨 +挨 +煅 +短 +匙 +粕 +镜 +赣 +撕 +墩 +酬 +馁 +豌 +颐 +抗 +酣 +氓 +佑 +搁 +哭 +递 +耷 +涡 +桃 +贻 +碣 +截 +瘦 +昭 +镌 +蔓 +氚 +甲 +猕 +蕴 +蓬 +散 +拾 +纛 +狼 +猷 +铎 +埋 +旖 +矾 +讳 +囊 +糜 +迈 +粟 +蚂 +紧 +鲳 +瘢 +栽 +稼 +羊 +锄 +斟 +睁 +桥 +瓮 +蹙 +祉 +醺 +鼻 +昱 +剃 +跳 +篱 +跷 +蒜 +翎 +宅 +晖 +嗑 +壑 +峻 +癫 +屏 +狠 +陋 +袜 +途 +憎 +祀 +莹 +滟 +佶 +溥 +臣 +约 +盛 +峰 +磁 +慵 +婪 +拦 +莅 +朕 +鹦 +粲 +裤 +哎 +疡 +嫖 +琵 +窟 +堪 +谛 +嘉 +儡 +鳝 +斩 +郾 +驸 +酊 +妄 +胜 +贺 +徙 +傅 +噌 +钢 +栅 +庇 +恋 +匝 +巯 +邈 +尸 +锚 +粗 +佟 +蛟 +薹 +纵 +蚊 +郅 +绢 +锐 +苗 +俞 +篆 +淆 +膀 +鲜 +煎 +诶 +秽 +寻 +涮 +刺 +怀 +噶 +巨 +褰 +魅 +灶 +灌 +桉 +藕 +谜 +舸 +薄 +搀 +恽 +借 +牯 +痉 +渥 +愿 +亓 +耘 +杠 +柩 +锔 +蚶 +钣 +珈 +喘 +蹒 +幽 +赐 +稗 +晤 +莱 +泔 +扯 +肯 +菪 +裆 +腩 +豉 +疆 +骜 +腐 +倭 +珏 +唔 +粮 +亡 +润 +慰 +伽 +橄 +玄 +誉 +醐 +胆 +龊 +粼 +塬 +陇 +彼 +削 +嗣 +绾 +芽 +妗 +垭 +瘴 +爽 +薏 +寨 +龈 +泠 +弹 +赢 +漪 +猫 +嘧 +涂 +恤 +圭 +茧 +烽 +屑 +痕 +巾 +赖 +荸 +凰 +腮 +畈 +亵 +蹲 +偃 +苇 +澜 +艮 +换 +骺 +烘 +苕 +梓 +颉 +肇 +哗 +悄 +氤 +涠 +葬 +屠 +鹭 +植 +竺 +佯 +诣 +鲇 +瘀 +鲅 +邦 +移 +滁 +冯 +耕 +癔 +戌 +茬 +沁 +巩 +悠 +湘 +洪 +痹 +锟 +循 +谋 +腕 +鳃 +钠 +捞 +焉 +迎 +碱 +伫 +急 +榷 +奈 +邝 +卯 +辄 +皲 +卟 +醛 +畹 +忧 +稳 +雄 +昼 +缩 +阈 +睑 +扌 +耗 +曦 +涅 +捏 +瞧 +邕 +淖 +漉 +铝 +耦 +禹 +湛 +喽 +莼 +琅 +诸 +苎 +纂 +硅 +始 +嗨 +傥 +燃 +臂 +赅 +嘈 +呆 +贵 +屹 +壮 +肋 +亍 +蚀 +卅 +豹 +腆 +邬 +迭 +浊 +} +童 +螂 +捐 +圩 +勐 +触 +寞 +汊 +壤 +荫 +膺 +渌 +芳 +懿 +遴 +螈 +泰 +蓼 +蛤 +茜 +舅 +枫 +朔 +膝 +眙 +避 +梅 +判 +鹜 +璜 +牍 +缅 +垫 +藻 +黔 +侥 +惚 +懂 +踩 +腰 +腈 +札 +丞 +唾 +慈 +顿 +摹 +荻 +琬 +~ +斧 +沈 +滂 +胁 +胀 +幄 +莜 +Z +匀 +鄄 +掌 +绰 +茎 +焚 +赋 +萱 +谑 +汁 +铒 +瞎 +夺 +蜗 +野 +娆 +冀 +弯 +篁 +懵 +灞 +隽 +芡 +脘 +俐 +辩 +芯 +掺 +喏 +膈 +蝈 +觐 +悚 +踹 +蔗 +熠 +鼠 +呵 +抓 +橼 +峨 +畜 +缔 +禾 +崭 +弃 +熊 +摒 +凸 +拗 +穹 +蒙 +抒 +祛 +劝 +闫 +扳 +阵 +醌 +踪 +喵 +侣 +搬 +仅 +荧 +赎 +蝾 +琦 +买 +婧 +瞄 +寓 +皎 +冻 +赝 +箩 +莫 +瞰 +郊 +笫 +姝 +筒 +枪 +遣 +煸 +袋 +舆 +痱 +涛 +母 +〇 +启 +践 +耙 +绲 +盘 +遂 +昊 +搞 +槿 +诬 +纰 +泓 +惨 +檬 +亻 +越 +C +o +憩 +熵 +祷 +钒 +暧 +塔 +阗 +胰 +咄 +娶 +魔 +琶 +钞 +邻 +扬 +杉 +殴 +咽 +弓 +〆 +髻 +】 +吭 +揽 +霆 +拄 +殖 +脆 +彻 +岩 +芝 +勃 +辣 +剌 +钝 +嘎 +甄 +佘 +皖 +伦 +授 +徕 +憔 +挪 +皇 +庞 +稔 +芜 +踏 +溴 +兖 +卒 +擢 +饥 +鳞 +煲 +‰ +账 +颗 +叻 +斯 +捧 +鳍 +琮 +讹 +蛙 +纽 +谭 +酸 +兔 +莒 +睇 +伟 +觑 +羲 +嗜 +宜 +褐 +旎 +辛 +卦 +诘 +筋 +鎏 +溪 +挛 +熔 +阜 +晰 +鳅 +丢 +奚 +灸 +呱 +献 +陉 +黛 +鸪 +甾 +萨 +疮 +拯 +洲 +疹 +辑 +叙 +恻 +谒 +允 +柔 +烂 +氏 +逅 +漆 +拎 +惋 +扈 +湟 +纭 +啕 +掬 +擞 +哥 +忽 +涤 +鸵 +靡 +郗 +瓷 +扁 +廊 +怨 +雏 +钮 +敦 +E +懦 +憋 +汀 +拚 +啉 +腌 +岸 +f +痼 +瞅 +尊 +咀 +眩 +飙 +忌 +仝 +迦 +熬 +毫 +胯 +篑 +茄 +腺 +凄 +舛 +碴 +锵 +诧 +羯 +後 +漏 +汤 +宓 +仞 +蚁 +壶 +谰 +皑 +铄 +棰 +罔 +辅 +晶 +苦 +牟 +闽 +\ +烃 +饮 +聿 +丙 +蛳 +朱 +煤 +涔 +鳖 +犁 +罐 +荼 +砒 +淦 +妤 +黏 +戎 +孑 +婕 +瑾 +戢 +钵 +枣 +捋 +砥 +衩 +狙 +桠 +稣 +阎 +肃 +梏 +诫 +孪 +昶 +婊 +衫 +嗔 +侃 +塞 +蜃 +樵 +峒 +貌 +屿 +欺 +缫 +阐 +栖 +诟 +珞 +荭 +吝 +萍 +嗽 +恂 +啻 +蜴 +磬 +峋 +俸 +豫 +谎 +徊 +镍 +韬 +魇 +晴 +U +囟 +猜 +蛮 +坐 +囿 +伴 +亭 +肝 +佗 +蝠 +妃 +胞 +滩 +榴 +氖 +垩 +苋 +砣 +扪 +馏 +姓 +轩 +厉 +夥 +侈 +禀 +垒 +岑 +赏 +钛 +辐 +痔 +披 +纸 +碳 +“ +坞 +蠓 +挤 +荥 +沅 +悔 +铧 +帼 +蒌 +蝇 +a +p +y +n +g +哀 +浆 +瑶 +凿 +桶 +馈 +皮 +奴 +苜 +佤 +伶 +晗 +铱 +炬 +优 +弊 +氢 +恃 +甫 +攥 +端 +锌 +灰 +稹 +炝 +曙 +邋 +亥 +眶 +碾 +拉 +萝 +绔 +捷 +浍 +腋 +姑 +菖 +凌 +涞 +麽 +锢 +桨 +潢 +绎 +镰 +殆 +锑 +渝 +铬 +困 +绽 +觎 +匈 +糙 +暑 +裹 +鸟 +盔 +肽 +迷 +綦 +『 +亳 +佝 +俘 +钴 +觇 +骥 +仆 +疝 +跪 +婶 +郯 +瀹 +唉 +脖 +踞 +针 +晾 +忒 +扼 +瞩 +叛 +椒 +疟 +嗡 +邗 +肆 +跆 +玫 +忡 +捣 +咧 +唆 +艄 +蘑 +潦 +笛 +阚 +沸 +泻 +掊 +菽 +贫 +斥 +髂 +孢 +镂 +赂 +麝 +鸾 +屡 +衬 +苷 +恪 +叠 +希 +粤 +爻 +喝 +茫 +惬 +郸 +绻 +庸 +撅 +碟 +宄 +妹 +膛 +叮 +饵 +崛 +嗲 +椅 +冤 +搅 +咕 +敛 +尹 +垦 +闷 +蝉 +霎 +勰 +败 +蓑 +泸 +肤 +鹌 +幌 +焦 +浠 +鞍 +刁 +舰 +乙 +竿 +裔 +。 +茵 +函 +伊 +兄 +丨 +娜 +匍 +謇 +莪 +宥 +似 +蝽 +翳 +酪 +翠 +粑 +薇 +祢 +骏 +赠 +叫 +Q +噤 +噻 +竖 +芗 +莠 +潭 +俊 +羿 +耜 +O +郫 +趁 +嗪 +囚 +蹶 +芒 +洁 +笋 +鹑 +敲 +硝 +啶 +堡 +渲 +揩 +』 +携 +宿 +遒 +颍 +扭 +棱 +割 +萜 +蔸 +葵 +琴 +捂 +饰 +衙 +耿 +掠 +募 +岂 +窖 +涟 +蔺 +瘤 +柞 +瞪 +怜 +匹 +距 +楔 +炜 +哆 +秦 +缎 +幼 +茁 +绪 +痨 +恨 +楸 +娅 +瓦 +桩 +雪 +嬴 +伏 +榔 +妥 +铿 +拌 +眠 +雍 +缇 +‘ +卓 +搓 +哌 +觞 +噩 +屈 +哧 +髓 +咦 +巅 +娑 +侑 +淫 +膳 +祝 +勾 +姊 +莴 +胄 +疃 +薛 +蜷 +胛 +巷 +芙 +芋 +熙 +闰 +勿 +窃 +狱 +剩 +钏 +幢 +陟 +铛 +慧 +靴 +耍 +k +浙 +浇 +飨 +惟 +绗 +祜 +澈 +啼 +咪 +磷 +摞 +诅 +郦 +抹 +跃 +壬 +吕 +肖 +琏 +颤 +尴 +剡 +抠 +凋 +赚 +泊 +津 +宕 +殷 +倔 +氲 +漫 +邺 +涎 +怠 +$ +垮 +荬 +遵 +俏 +叹 +噢 +饽 +蜘 +孙 +筵 +疼 +鞭 +羧 +牦 +箭 +潴 +c +眸 +祭 +髯 +啖 +坳 +愁 +芩 +驮 +倡 +巽 +穰 +沃 +胚 +怒 +凤 +槛 +剂 +趵 +嫁 +v +邢 +灯 +鄢 +桐 +睽 +檗 +锯 +槟 +婷 +嵋 +圻 +诗 +蕈 +颠 +遭 +痢 +芸 +怯 +馥 +竭 +锗 +徜 +恭 +遍 +籁 +剑 +嘱 +苡 +龄 +僧 +桑 +潸 +弘 +澶 +楹 +悲 +讫 +愤 +腥 +悸 +谍 +椹 +呢 +桓 +葭 +攫 +阀 +翰 +躲 +敖 +柑 +郎 +笨 +橇 +呃 +魁 +燎 +脓 +葩 +磋 +垛 +玺 +狮 +沓 +砜 +蕊 +锺 +罹 +蕉 +翱 +虐 +闾 +巫 +旦 +茱 +嬷 +枯 +鹏 +贡 +芹 +汛 +矫 +绁 +拣 +禺 +佃 +讣 +舫 +惯 +乳 +趋 +疲 +挽 +岚 +虾 +衾 +蠹 +蹂 +飓 +氦 +铖 +孩 +稞 +瑜 +壅 +掀 +勘 +妓 +畅 +髋 +W +庐 +牲 +蓿 +榕 +练 +垣 +唱 +邸 +菲 +昆 +婺 +穿 +绡 +麒 +蚱 +掂 +愚 +泷 +涪 +漳 +妩 +娉 +榄 +讷 +觅 +旧 +藤 +煮 +呛 +柳 +腓 +叭 +庵 +烷 +阡 +罂 +蜕 +擂 +猖 +咿 +媲 +脉 +【 +沏 +貅 +黠 +熏 +哲 +烁 +坦 +酵 +兜 +× +潇 +撒 +剽 +珩 +圹 +乾 +摸 +樟 +帽 +嗒 +襄 +魂 +轿 +憬 +锡 +〕 +喃 +皆 +咖 +隅 +脸 +残 +泮 +袂 +鹂 +珊 +囤 +捆 +咤 +误 +徨 +闹 +淙 +芊 +淋 +怆 +囗 +拨 +梳 +渤 +R +G +绨 +蚓 +婀 +幡 +狩 +麾 +谢 +唢 +裸 +旌 +伉 +纶 +裂 +驳 +砼 +咛 +澄 +樨 +蹈 +宙 +澍 +倍 +貔 +操 +勇 +蟠 +摈 +砧 +虬 +够 +缁 +悦 +藿 +撸 +艹 +摁 +淹 +豇 +虎 +榭 +ˉ +吱 +d +° +喧 +荀 +踱 +侮 +奋 +偕 +饷 +犍 +惮 +坑 +璎 +徘 +宛 +妆 +袈 +倩 +窦 +昂 +荏 +乖 +K +怅 +撰 +鳙 +牙 +袁 +酞 +X +痿 +琼 +闸 +雁 +趾 +荚 +虻 +涝 +《 +杏 +韭 +偈 +烤 +绫 +鞘 +卉 +症 +遢 +蓥 +诋 +杭 +荨 +匆 +竣 +簪 +辙 +敕 +虞 +丹 +缭 +咩 +黟 +m +淤 +瑕 +咂 +铉 +硼 +茨 +嶂 +痒 +畸 +敬 +涿 +粪 +窘 +熟 +叔 +嫔 +盾 +忱 +裘 +憾 +梵 +赡 +珙 +咯 +娘 +庙 +溯 +胺 +葱 +痪 +摊 +荷 +卞 +乒 +髦 +寐 +铭 +坩 +胗 +枷 +爆 +溟 +嚼 +羚 +砬 +轨 +惊 +挠 +罄 +竽 +菏 +氧 +浅 +楣 +盼 +枢 +炸 +阆 +杯 +谏 +噬 +淇 +渺 +俪 +秆 +墓 +泪 +跻 +砌 +痰 +垡 +渡 +耽 +釜 +讶 +鳎 +煞 +呗 +韶 +舶 +绷 +鹳 +缜 +旷 +铊 +皱 +龌 +檀 +霖 +奄 +槐 +艳 +蝶 +旋 +哝 +赶 +骞 +蚧 +腊 +盈 +丁 +` +蜚 +矸 +蝙 +睨 +嚓 +僻 +鬼 +醴 +夜 +彝 +磊 +笔 +拔 +栀 +糕 +厦 +邰 +纫 +逭 +纤 +眦 +膊 +馍 +躇 +烯 +蘼 +冬 +诤 +暄 +骶 +哑 +瘠 +」 +臊 +丕 +愈 +咱 +螺 +擅 +跋 +搏 +硪 +谄 +笠 +淡 +嘿 +骅 +谧 +鼎 +皋 +姚 +歼 +蠢 +驼 +耳 +胬 +挝 +涯 +狗 +蒽 +孓 +犷 +凉 +芦 +箴 +铤 +孤 +嘛 +坤 +V +茴 +朦 +挞 +尖 +橙 +诞 +搴 +碇 +洵 +浚 +帚 +蜍 +漯 +柘 +嚎 +讽 +芭 +荤 +咻 +祠 +秉 +跖 +埃 +吓 +糯 +眷 +馒 +惹 +娼 +鲑 +嫩 +讴 +轮 +瞥 +靶 +褚 +乏 +缤 +宋 +帧 +删 +驱 +碎 +扑 +俩 +俄 +偏 +涣 +竹 +噱 +皙 +佰 +渚 +唧 +斡 +# +镉 +刀 +崎 +筐 +佣 +夭 +贰 +肴 +峙 +哔 +艿 +匐 +牺 +镛 +缘 +仡 +嫡 +劣 +枸 +堀 +梨 +簿 +鸭 +蒸 +亦 +稽 +浴 +{ +衢 +束 +槲 +j +阁 +揍 +疥 +棋 +潋 +聪 +窜 +乓 +睛 +插 +冉 +阪 +苍 +搽 +「 +蟾 +螟 +幸 +仇 +樽 +撂 +慢 +跤 +幔 +俚 +淅 +覃 +觊 +溶 +妖 +帛 +侨 +曰 +妾 +泗 +· +: +瀘 +風 +Ë +( +) +∶ +紅 +紗 +瑭 +雲 +頭 +鶏 +財 +許 +• +¥ +樂 +焗 +麗 +— +; +滙 +東 +榮 +繪 +興 +… +門 +業 +π +楊 +國 +顧 +é +盤 +寳 +Λ +龍 +鳳 +島 +誌 +緣 +結 +銭 +萬 +勝 +祎 +璟 +優 +歡 +臨 +時 +購 += +★ +藍 +昇 +鐵 +觀 +勅 +農 +聲 +畫 +兿 +術 +發 +劉 +記 +專 +耑 +園 +書 +壴 +種 +Ο +● +褀 +號 +銀 +匯 +敟 +锘 +葉 +橪 +廣 +進 +蒄 +鑽 +阝 +祙 +貢 +鍋 +豊 +夬 +喆 +團 +閣 +開 +燁 +賓 +館 +酡 +沔 +順 ++ +硚 +劵 +饸 +陽 +車 +湓 +復 +萊 +氣 +軒 +華 +堃 +迮 +纟 +戶 +馬 +學 +裡 +電 +嶽 +獨 +マ +シ +サ +ジ +燘 +袪 +環 +❤ +臺 +灣 +専 +賣 +孖 +聖 +攝 +線 +▪ +α +傢 +俬 +夢 +達 +莊 +喬 +貝 +薩 +劍 +羅 +壓 +棛 +饦 +尃 +璈 +囍 +醫 +G +I +A +# +N +鷄 +髙 +嬰 +啓 +約 +隹 +潔 +賴 +藝 +~ +寶 +籣 +麺 +  +嶺 +√ +義 +網 +峩 +長 +∧ +魚 +機 +構 +② +鳯 +偉 +L +B +㙟 +畵 +鴿 +' +詩 +溝 +嚞 +屌 +藔 +佧 +玥 +蘭 +織 +1 +3 +9 +0 +7 +點 +砭 +鴨 +鋪 +銘 +廳 +弍 +‧ +創 +湯 +坶 +℃ +卩 +骝 +& +烜 +荘 +當 +潤 +扞 +係 +懷 +碶 +钅 +蚨 +讠 +☆ +叢 +爲 +埗 +涫 +塗 +→ +楽 +現 +鯨 +愛 +瑪 +鈺 +忄 +悶 +藥 +飾 +樓 +視 +孬 +ㆍ +燚 +苪 +師 +① +丼 +锽 +│ +韓 +標 +è +兒 +閏 +匋 +張 +漢 +Ü +髪 +會 +閑 +檔 +習 +裝 +の +峯 +菘 +輝 +И +雞 +釣 +億 +浐 +K +O +R +8 +H +E +P +T +W +D +S +C +M +F +姌 +饹 +» +晞 +廰 +ä +嵯 +鷹 +負 +飲 +絲 +冚 +楗 +澤 +綫 +區 +❋ +← +質 +靑 +揚 +③ +滬 +統 +産 +協 +﹑ +乸 +畐 +經 +運 +際 +洺 +岽 +為 +粵 +諾 +崋 +豐 +碁 +ɔ +V +2 +6 +齋 +誠 +訂 +´ +勑 +雙 +陳 +無 +í +泩 +媄 +夌 +刂 +i +c +t +o +r +a +嘢 +耄 +燴 +暃 +壽 +媽 +靈 +抻 +體 +唻 +É +冮 +甹 +鎮 +錦 +ʌ +蜛 +蠄 +尓 +駕 +戀 +飬 +逹 +倫 +貴 +極 +Я +Й +寬 +磚 +嶪 +郎 +職 +| +間 +n +d +剎 +伈 +課 +飛 +橋 +瘊 +№ +譜 +骓 +圗 +滘 +縣 +粿 +咅 +養 +濤 +彳 +® +% +Ⅱ +啰 +㴪 +見 +矞 +薬 +糁 +邨 +鲮 +顔 +罱 +З +選 +話 +贏 +氪 +俵 +競 +瑩 +繡 +枱 +β +綉 +á +獅 +爾 +™ +麵 +戋 +淩 +徳 +個 +劇 +場 +務 +簡 +寵 +h +實 +膠 +轱 +圖 +築 +嘣 +樹 +㸃 +營 +耵 +孫 +饃 +鄺 +飯 +麯 +遠 +輸 +坫 +孃 +乚 +閃 +鏢 +㎡ +題 +廠 +關 +↑ +爺 +將 +軍 +連 +篦 +覌 +參 +箸 +- +窠 +棽 +寕 +夀 +爰 +歐 +呙 +閥 +頡 +熱 +雎 +垟 +裟 +凬 +勁 +帑 +馕 +夆 +疌 +枼 +馮 +貨 +蒤 +樸 +彧 +旸 +靜 +龢 +暢 +㐱 +鳥 +珺 +鏡 +灡 +爭 +堷 +廚 +Ó +騰 +診 +┅ +蘇 +褔 +凱 +頂 +豕 +亞 +帥 +嘬 +⊥ +仺 +桖 +複 +饣 +絡 +穂 +顏 +棟 +納 +▏ +濟 +親 +設 +計 +攵 +埌 +烺 +ò +頤 +燦 +蓮 +撻 +節 +講 +濱 +濃 +娽 +洳 +朿 +燈 +鈴 +護 +膚 +铔 +過 +補 +Z +U +5 +4 +坋 +闿 +䖝 +餘 +缐 +铞 +貿 +铪 +桼 +趙 +鍊 +[ +㐂 +垚 +菓 +揸 +捲 +鐘 +滏 +𣇉 +爍 +輪 +燜 +鴻 +鮮 +動 +鹞 +鷗 +丄 +慶 +鉌 +翥 +飮 +腸 +⇋ +漁 +覺 +來 +熘 +昴 +翏 +鲱 +圧 +鄉 +萭 +頔 +爐 +嫚 +г +貭 +類 +聯 +幛 +輕 +訓 +鑒 +夋 +锨 +芃 +珣 +䝉 +扙 +嵐 +銷 +處 +ㄱ +語 +誘 +苝 +歸 +儀 +燒 +楿 +內 +粢 +葒 +奧 +麥 +礻 +滿 +蠔 +穵 +瞭 +態 +鱬 +榞 +硂 +鄭 +黃 +煙 +祐 +奓 +逺 +* +瑄 +獲 +聞 +薦 +讀 +這 +樣 +決 +問 +啟 +們 +執 +説 +轉 +單 +隨 +唘 +帶 +倉 +庫 +還 +贈 +尙 +皺 +■ +餅 +產 +○ +∈ +報 +狀 +楓 +賠 +琯 +嗮 +禮 +` +傳 +> +≤ +嗞 +Φ +≥ +換 +咭 +∣ +↓ +曬 +ε +応 +寫 +″ +終 +様 +純 +費 +療 +聨 +凍 +壐 +郵 +ü +黒 +∫ +製 +塊 +調 +軽 +確 +撃 +級 +馴 +Ⅲ +涇 +繹 +數 +碼 +證 +狒 +処 +劑 +< +晧 +賀 +衆 +] +櫥 +兩 +陰 +絶 +對 +鯉 +憶 +◎ +p +e +Y +蕒 +煖 +頓 +測 +試 +鼽 +僑 +碩 +妝 +帯 +≈ +鐡 +舖 +權 +喫 +倆 +ˋ +該 +悅 +ā +俫 +. +f +s +b +m +k +g +u +j +貼 +淨 +濕 +針 +適 +備 +l +/ +給 +謢 +強 +觸 +衛 +與 +⊙ +$ +緯 +變 +⑴ +⑵ +⑶ +㎏ +殺 +∩ +幚 +─ +價 +▲ +離 +ú +ó +飄 +烏 +関 +閟 +﹝ +﹞ +邏 +輯 +鍵 +驗 +訣 +導 +歷 +屆 +層 +▼ +儱 +錄 +熳 +ē +艦 +吋 +錶 +辧 +飼 +顯 +④ +禦 +販 +気 +対 +枰 +閩 +紀 +幹 +瞓 +貊 +淚 +△ +眞 +墊 +Ω +獻 +褲 +縫 +緑 +亜 +鉅 +餠 +{ +} +◆ +蘆 +薈 +█ +◇ +溫 +彈 +晳 +粧 +犸 +穩 +訊 +崬 +凖 +熥 +П +舊 +條 +紋 +圍 +Ⅳ +筆 +尷 +難 +雜 +錯 +綁 +識 +頰 +鎖 +艶 +□ +殁 +殼 +⑧ +├ +▕ +鵬 +ǐ +ō +ǒ +糝 +綱 +▎ +μ +盜 +饅 +醬 +籤 +蓋 +釀 +鹽 +據 +à +ɡ +辦 +◥ +彐 +┌ +婦 +獸 +鲩 +伱 +ī +蒟 +蒻 +齊 +袆 +腦 +寧 +凈 +妳 +煥 +詢 +偽 +謹 +啫 +鯽 +騷 +鱸 +損 +傷 +鎻 +髮 +買 +冏 +儥 +両 +﹢ +∞ +載 +喰 +z +羙 +悵 +燙 +曉 +員 +組 +徹 +艷 +痠 +鋼 +鼙 +縮 +細 +嚒 +爯 +≠ +維 +" +鱻 +壇 +厍 +帰 +浥 +犇 +薡 +軎 +² +應 +醜 +刪 +緻 +鶴 +賜 +噁 +軌 +尨 +镔 +鷺 +槗 +彌 +葚 +濛 +請 +溇 +緹 +賢 +訪 +獴 +瑅 +資 +縤 +陣 +蕟 +栢 +韻 +祼 +恁 +伢 +謝 +劃 +涑 +總 +衖 +踺 +砋 +凉 +籃 +駿 +苼 +瘋 +昽 +紡 +驊 +腎 +﹗ +響 +杋 +剛 +嚴 +禪 +歓 +槍 +傘 +檸 +檫 +炣 +勢 +鏜 +鎢 +銑 +尐 +減 +奪 +惡 +θ +僮 +婭 +臘 +ū +ì +殻 +鉄 +∑ +蛲 +焼 +緖 +續 +紹 +懮 \ No newline at end of file diff --git a/custom-demo/back-end/model/anytext/utils.py b/custom-demo/back-end/model/anytext/utils.py new file mode 100644 index 0000000..c9f55b8 --- /dev/null +++ b/custom-demo/back-end/model/anytext/utils.py @@ -0,0 +1,151 @@ +import os +import datetime +import cv2 +import numpy as np +from PIL import Image, ImageDraw + + +def save_images(img_list, folder): + if not os.path.exists(folder): + os.makedirs(folder) + now = datetime.datetime.now() + date_str = now.strftime("%Y-%m-%d") + folder_path = os.path.join(folder, date_str) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + time_str = now.strftime("%H_%M_%S") + for idx, img in enumerate(img_list): + image_number = idx + 1 + filename = f"{time_str}_{image_number}.jpg" + save_path = os.path.join(folder_path, filename) + cv2.imwrite(save_path, img[..., ::-1]) + + +def check_channels(image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + +def resize_image(img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + +def insert_spaces(string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + +def draw_glyph(font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + text_width, text_height = new_font.getsize(text) + offset_x, offset_y = new_font.getoffset(text) + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - offset_y // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + +def draw_glyph2( + font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True +): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img diff --git a/custom-demo/back-end/model/base.py b/custom-demo/back-end/model/base.py new file mode 100644 index 0000000..8d79179 --- /dev/null +++ b/custom-demo/back-end/model/base.py @@ -0,0 +1,418 @@ +import abc +from typing import Optional + +import cv2 +import torch +import numpy as np +from loguru import logger + +from iopaint.helper import ( + boxes_from_mask, + resize_max_size, + pad_img_to_modulo, + switch_mps_device, +) +from iopaint.schema import InpaintRequest, HDStrategy, SDSampler +from .helper.g_diffuser_bot import expand_image, expand_image2 +from .utils import get_scheduler + + +class InpaintModel: + name = "base" + min_size: Optional[int] = None + pad_mod = 8 + pad_to_square = False + is_erase_model = False + + def __init__(self, device, **kwargs): + """ + + Args: + device: + """ + device = switch_mps_device(self.name, device) + self.device = device + self.init_model(device, **kwargs) + + @abc.abstractmethod + def init_model(self, device, **kwargs): + ... + + @staticmethod + @abc.abstractmethod + def is_downloaded() -> bool: + return False + + @abc.abstractmethod + def forward(self, image, mask, config: InpaintRequest): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W, 1] 255 为 masks 区域 + return: BGR IMAGE + """ + ... + + @staticmethod + def download(): + ... + + def _pad_forward(self, image, mask, config: InpaintRequest): + origin_height, origin_width = image.shape[:2] + pad_image = pad_img_to_modulo( + image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) + pad_mask = pad_img_to_modulo( + mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) + + # logger.info(f"final forward pad size: {pad_image.shape}") + + image, mask = self.forward_pre_process(image, mask, config) + + result = self.forward(pad_image, pad_mask, config) + result = result[0:origin_height, 0:origin_width, :] + + result, image, mask = self.forward_post_process(result, image, mask, config) + + if config.sd_keep_unmasked_area: + mask = mask[:, :, np.newaxis] + result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255)) + return result + + def forward_pre_process(self, image, mask, config): + return image, mask + + def forward_post_process(self, result, image, mask, config): + return result, image, mask + + @torch.no_grad() + def __call__(self, image, mask, config: InpaintRequest): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + inpaint_result = None + # logger.info(f"hd_strategy: {config.hd_strategy}") + if config.hd_strategy == HDStrategy.CROP: + if max(image.shape) > config.hd_strategy_crop_trigger_size: + logger.info(f"Run crop strategy") + boxes = boxes_from_mask(mask) + crop_result = [] + for box in boxes: + crop_image, crop_box = self._run_box(image, mask, box, config) + crop_result.append((crop_image, crop_box)) + + inpaint_result = image[:, :, ::-1] + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + inpaint_result[y1:y2, x1:x2, :] = crop_image + + elif config.hd_strategy == HDStrategy.RESIZE: + if max(image.shape) > config.hd_strategy_resize_limit: + origin_size = image.shape[:2] + downsize_image = resize_max_size( + image, size_limit=config.hd_strategy_resize_limit + ) + downsize_mask = resize_max_size( + mask, size_limit=config.hd_strategy_resize_limit + ) + + logger.info( + f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}" + ) + inpaint_result = self._pad_forward( + downsize_image, downsize_mask, config + ) + + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + original_pixel_indices = mask < 127 + inpaint_result[original_pixel_indices] = image[:, :, ::-1][ + original_pixel_indices + ] + + if inpaint_result is None: + inpaint_result = self._pad_forward(image, mask, config) + + return inpaint_result + + def _crop_box(self, image, mask, box, config: InpaintRequest): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] + box: [left,top,right,bottom] + + Returns: + BGR IMAGE, (l, r, r, b) + """ + box_h = box[3] - box[1] + box_w = box[2] - box[0] + cx = (box[0] + box[2]) // 2 + cy = (box[1] + box[3]) // 2 + img_h, img_w = image.shape[:2] + + w = box_w + config.hd_strategy_crop_margin * 2 + h = box_h + config.hd_strategy_crop_margin * 2 + + _l = cx - w // 2 + _r = cx + w // 2 + _t = cy - h // 2 + _b = cy + h // 2 + + l = max(_l, 0) + r = min(_r, img_w) + t = max(_t, 0) + b = min(_b, img_h) + + # try to get more context when crop around image edge + if _l < 0: + r += abs(_l) + if _r > img_w: + l -= _r - img_w + if _t < 0: + b += abs(_t) + if _b > img_h: + t -= _b - img_h + + l = max(l, 0) + r = min(r, img_w) + t = max(t, 0) + b = min(b, img_h) + + crop_img = image[t:b, l:r, :] + crop_mask = mask[t:b, l:r] + + # logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}") + + return crop_img, crop_mask, [l, t, r, b] + + def _calculate_cdf(self, histogram): + cdf = histogram.cumsum() + normalized_cdf = cdf / float(cdf.max()) + return normalized_cdf + + def _calculate_lookup(self, source_cdf, reference_cdf): + lookup_table = np.zeros(256) + lookup_val = 0 + for source_index, source_val in enumerate(source_cdf): + for reference_index, reference_val in enumerate(reference_cdf): + if reference_val >= source_val: + lookup_val = reference_index + break + lookup_table[source_index] = lookup_val + return lookup_table + + def _match_histograms(self, source, reference, mask): + transformed_channels = [] + if len(mask.shape) == 3: + mask = mask[:, :, -1] + + for channel in range(source.shape[-1]): + source_channel = source[:, :, channel] + reference_channel = reference[:, :, channel] + + # only calculate histograms for non-masked parts + source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256]) + reference_histogram, _ = np.histogram( + reference_channel[mask == 0], 256, [0, 256] + ) + + source_cdf = self._calculate_cdf(source_histogram) + reference_cdf = self._calculate_cdf(reference_histogram) + + lookup = self._calculate_lookup(source_cdf, reference_cdf) + + transformed_channels.append(cv2.LUT(source_channel, lookup)) + + result = cv2.merge(transformed_channels) + result = cv2.convertScaleAbs(result) + + return result + + def _apply_cropper(self, image, mask, config: InpaintRequest): + img_h, img_w = image.shape[:2] + l, t, w, h = ( + config.croper_x, + config.croper_y, + config.croper_width, + config.croper_height, + ) + r = l + w + b = t + h + + l = max(l, 0) + r = min(r, img_w) + t = max(t, 0) + b = min(b, img_h) + + crop_img = image[t:b, l:r, :] + crop_mask = mask[t:b, l:r] + return crop_img, crop_mask, (l, t, r, b) + + def _run_box(self, image, mask, box, config: InpaintRequest): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] + box: [left,top,right,bottom] + + Returns: + BGR IMAGE + """ + crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config) + + return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b] + + +class DiffusionInpaintModel(InpaintModel): + def __init__(self, device, **kwargs): + self.model_info = kwargs["model_info"] + self.model_id_or_path = self.model_info.path + super().__init__(device, **kwargs) + + @torch.no_grad() + def __call__(self, image, mask, config: InpaintRequest): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + # boxes = boxes_from_mask(mask) + if config.use_croper: + crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config) + crop_image = self._scaled_pad_forward(crop_img, crop_mask, config) + inpaint_result = image[:, :, ::-1] + inpaint_result[t:b, l:r, :] = crop_image + elif config.use_extender: + inpaint_result = self._do_outpainting(image, config) + else: + inpaint_result = self._scaled_pad_forward(image, mask, config) + + return inpaint_result + + def _do_outpainting(self, image, config: InpaintRequest): + # cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数 + # 从 image 中 crop 出 outpainting 区域 + image_h, image_w = image.shape[:2] + cropper_l = config.extender_x + cropper_t = config.extender_y + cropper_r = config.extender_x + config.extender_width + cropper_b = config.extender_y + config.extender_height + image_l = 0 + image_t = 0 + image_r = image_w + image_b = image_h + + # 类似求 IOU + l = max(cropper_l, image_l) + t = max(cropper_t, image_t) + r = min(cropper_r, image_r) + b = min(cropper_b, image_b) + + assert ( + 0 <= l < r and 0 <= t < b + ), f"cropper and image not overlap, {l},{t},{r},{b}" + + cropped_image = image[t:b, l:r, :] + padding_l = max(0, image_l - cropper_l) + padding_t = max(0, image_t - cropper_t) + padding_r = max(0, cropper_r - image_r) + padding_b = max(0, cropper_b - image_b) + + expanded_image, mask_image = expand_image2( + cropped_image, + left=padding_l, + top=padding_t, + right=padding_r, + bottom=padding_b, + softness=config.sd_outpainting_softness, + space=config.sd_outpainting_space, + ) + + # 最终扩大了的 image, BGR + expanded_cropped_result_image = self._scaled_pad_forward( + expanded_image, mask_image, config + ) + + # RGB -> BGR + outpainting_image = cv2.copyMakeBorder( + image, + left=padding_l, + top=padding_t, + right=padding_r, + bottom=padding_b, + borderType=cv2.BORDER_CONSTANT, + value=0, + )[:, :, ::-1] + + # 把 cropped_result_image 贴到 outpainting_image 上,这一步不需要 blend + paste_t = 0 if config.extender_y < 0 else config.extender_y + paste_l = 0 if config.extender_x < 0 else config.extender_x + + outpainting_image[ + paste_t : paste_t + expanded_cropped_result_image.shape[0], + paste_l : paste_l + expanded_cropped_result_image.shape[1], + :, + ] = expanded_cropped_result_image + return outpainting_image + + def _scaled_pad_forward(self, image, mask, config: InpaintRequest): + longer_side_length = int(config.sd_scale * max(image.shape[:2])) + origin_size = image.shape[:2] + downsize_image = resize_max_size(image, size_limit=longer_side_length) + downsize_mask = resize_max_size(mask, size_limit=longer_side_length) + if config.sd_scale != 1: + logger.info( + f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}" + ) + inpaint_result = self._pad_forward(downsize_image, downsize_mask, config) + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + + # blend result, copy from g_diffuser_bot + # mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0) + # inpaint_result = np.clip( + # inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0 + # ) + # original_pixel_indices = mask < 127 + # inpaint_result[original_pixel_indices] = image[:, :, ::-1][ + # original_pixel_indices + # ] + return inpaint_result + + def set_scheduler(self, config: InpaintRequest): + scheduler_config = self.model.scheduler.config + sd_sampler = config.sd_sampler + if config.sd_lcm_lora and self.model_info.support_lcm_lora: + sd_sampler = SDSampler.lcm + logger.info(f"LCM Lora enabled, use {sd_sampler} sampler") + scheduler = get_scheduler(sd_sampler, scheduler_config) + self.model.scheduler = scheduler + + def forward_pre_process(self, image, mask, config): + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + + return image, mask + + def forward_post_process(self, result, image, mask, config): + if config.sd_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + # if config.sd_mask_blur != 0: + # k = 2 * config.sd_mask_blur + 1 + # mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask diff --git a/custom-demo/back-end/model/controlnet.py b/custom-demo/back-end/model/controlnet.py new file mode 100644 index 0000000..d52db01 --- /dev/null +++ b/custom-demo/back-end/model/controlnet.py @@ -0,0 +1,190 @@ +import PIL.Image +import cv2 +import torch +from diffusers import ControlNetModel +from loguru import logger +from iopaint.schema import InpaintRequest, ModelType + +from .base import DiffusionInpaintModel +from .helper.controlnet_preprocess import ( + make_canny_control_image, + make_openpose_control_image, + make_depth_control_image, + make_inpaint_control_image, +) +from .helper.cpu_text_encoder import CPUTextEncoderWrapper +from .original_sd_configs import get_config_files +from .utils import ( + get_scheduler, + handle_from_pretrained_exceptions, + get_torch_dtype, + enable_low_mem, + is_local_files_only, +) + + +class ControlNet(DiffusionInpaintModel): + name = "controlnet" + pad_mod = 8 + min_size = 512 + + @property + def lcm_lora_id(self): + if self.model_info.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SD_INPAINT, + ]: + return "latent-consistency/lcm-lora-sdv1-5" + if self.model_info.model_type in [ + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SDXL_INPAINT, + ]: + return "latent-consistency/lcm-lora-sdxl" + raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}") + + def init_model(self, device: torch.device, **kwargs): + model_info = kwargs["model_info"] + controlnet_method = kwargs["controlnet_method"] + + self.model_info = model_info + self.controlnet_method = controlnet_method + + model_kwargs = { + **kwargs.get("pipe_components", {}), + "local_files_only": is_local_files_only(**kwargs), + } + self.local_files_only = model_kwargs["local_files_only"] + + disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get( + "cpu_offload", False + ) + if disable_nsfw_checker: + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + self.torch_dtype = torch_dtype + + if model_info.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SD_INPAINT, + ]: + from diffusers import ( + StableDiffusionControlNetInpaintPipeline as PipeClass, + ) + elif model_info.model_type in [ + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SDXL_INPAINT, + ]: + from diffusers import ( + StableDiffusionXLControlNetInpaintPipeline as PipeClass, + ) + + controlnet = ControlNetModel.from_pretrained( + pretrained_model_name_or_path=controlnet_method, + resume_download=True, + local_files_only=model_kwargs["local_files_only"], + torch_dtype=self.torch_dtype, + ) + if model_info.is_single_file_diffusers: + if self.model_info.model_type == ModelType.DIFFUSERS_SD: + model_kwargs["num_in_channels"] = 4 + else: + model_kwargs["num_in_channels"] = 9 + + self.model = PipeClass.from_single_file( + model_info.path, + controlnet=controlnet, + load_safety_checker=not disable_nsfw_checker, + torch_dtype=torch_dtype, + config_files=get_config_files(), + **model_kwargs, + ) + else: + self.model = handle_from_pretrained_exceptions( + PipeClass.from_pretrained, + pretrained_model_name_or_path=model_info.path, + controlnet=controlnet, + variant="fp16", + torch_dtype=torch_dtype, + **model_kwargs, + ) + + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def switch_controlnet_method(self, new_method: str): + self.controlnet_method = new_method + controlnet = ControlNetModel.from_pretrained( + new_method, + resume_download=True, + local_files_only=self.local_files_only, + torch_dtype=self.torch_dtype, + ).to(self.model.device) + self.model.controlnet = controlnet + + def _get_control_image(self, image, mask): + if "canny" in self.controlnet_method: + control_image = make_canny_control_image(image) + elif "openpose" in self.controlnet_method: + control_image = make_openpose_control_image(image) + elif "depth" in self.controlnet_method: + control_image = make_depth_control_image(image) + elif "inpaint" in self.controlnet_method: + control_image = make_inpaint_control_image(image, mask) + else: + raise NotImplementedError(f"{self.controlnet_method} not implemented") + return control_image + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + scheduler_config = self.model.scheduler.config + scheduler = get_scheduler(config.sd_sampler, scheduler_config) + self.model.scheduler = scheduler + + img_h, img_w = image.shape[:2] + control_image = self._get_control_image(image, mask) + mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") + image = PIL.Image.fromarray(image) + + output = self.model( + image=image, + mask_image=mask_image, + control_image=control_image, + prompt=config.prompt, + negative_prompt=config.negative_prompt, + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback_on_step_end=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + controlnet_conditioning_scale=config.controlnet_conditioning_scale, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output diff --git a/custom-demo/back-end/model/ddim_sampler.py b/custom-demo/back-end/model/ddim_sampler.py new file mode 100644 index 0000000..a3f44fd --- /dev/null +++ b/custom-demo/back-end/model/ddim_sampler.py @@ -0,0 +1,193 @@ +import torch +import numpy as np +from tqdm import tqdm + +from .utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like + +from loguru import logger + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear"): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + # array([1]) + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000]) + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample(self, steps, conditioning, batch_size, shape): + self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # samples: 1,3,128,128 + return self.ddim_sampling( + conditioning, + size, + quantize_denoised=False, + ddim_use_original_steps=False, + noise_dropout=0, + temperature=1.0, + ) + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + ddim_use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + device = self.model.betas.device + b = shape[0] + img = torch.randn(shape, device=device, dtype=cond.dtype) + timesteps = ( + self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + ) + + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + logger.info(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + ) + img, _ = outs + + return img + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + b, *_, device = *x.shape, x.device + e_t = self.model.apply_model(x, t, c) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: # 没用 + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: # 没用 + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/custom-demo/back-end/model/fcf.py b/custom-demo/back-end/model/fcf.py new file mode 100644 index 0000000..a6f2d42 --- /dev/null +++ b/custom-demo/back-end/model/fcf.py @@ -0,0 +1,1737 @@ +import os +import random + +import cv2 +import torch +import numpy as np +import torch.fft as fft + +from iopaint.schema import InpaintRequest + +from iopaint.helper import ( + load_model, + get_cache_path_by_url, + norm_img, + boxes_from_mask, + resize_max_size, + download_model, +) +from .base import InpaintModel +from torch import conv2d, nn +import torch.nn.functional as F + +from .utils import ( + setup_filter, + _parse_scaling, + _parse_padding, + Conv2dLayer, + FullyConnectedLayer, + MinibatchStdLayer, + activation_funcs, + conv2d_resample, + bias_act, + upsample2d, + normalize_2nd_moment, + downsample2d, +) + + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"): + assert isinstance(x, torch.Tensor) + return _upfirdn2d_ref( + x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ) + + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad( + x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)] + ) + x = x[ + :, + :, + max(-pady0, 0) : x.shape[2] - max(-pady1, 0), + max(-padx0, 0) : x.shape[3] - max(-padx1, 0), + ] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +class EncoderEpilogue(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + architecture="resnet", # Architecture: 'orig', 'skip', 'resnet'. + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + assert architecture in ["orig", "skip", "resnet"] + super().__init__() + self.in_channels = in_channels + self.cmap_dim = cmap_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + if architecture == "skip": + self.fromrgb = Conv2dLayer( + self.img_channels, in_channels, kernel_size=1, activation=activation + ) + self.mbstd = ( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + if mbstd_num_channels > 0 + else None + ) + self.conv = Conv2dLayer( + in_channels + mbstd_num_channels, + in_channels, + kernel_size=3, + activation=activation, + conv_clamp=conv_clamp, + ) + self.fc = FullyConnectedLayer( + in_channels * (resolution**2), z_dim, activation=activation + ) + self.dropout = torch.nn.Dropout(p=0.5) + + def forward(self, x, cmap, force_fp32=False): + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + # FromRGB. + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.mbstd is not None: + x = self.mbstd(x) + const_e = self.conv(x) + x = self.fc(const_e.flatten(1)) + x = self.dropout(x) + + # Conditioning. + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + assert x.dtype == dtype + return x, const_e + + +class EncoderBlock(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + tmp_channels, # Number of intermediate channels. + out_channels, # Number of output channels. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + first_layer_idx, # Index of the first layer. + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + freeze_layers=0, # Freeze-D: Number of layers to freeze. + ): + assert in_channels in [0, tmp_channels] + assert architecture in ["orig", "skip", "resnet"] + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.img_channels = img_channels + 1 + self.first_layer_idx = first_layer_idx + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = use_fp16 and fp16_channels_last + self.register_buffer("resample_filter", setup_filter(resample_filter)) + + self.num_layers = 0 + + def trainable_gen(): + while True: + layer_idx = self.first_layer_idx + self.num_layers + trainable = layer_idx >= freeze_layers + self.num_layers += 1 + yield trainable + + trainable_iter = trainable_gen() + + if in_channels == 0: + self.fromrgb = Conv2dLayer( + self.img_channels, + tmp_channels, + kernel_size=1, + activation=activation, + trainable=next(trainable_iter), + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + + self.conv0 = Conv2dLayer( + tmp_channels, + tmp_channels, + kernel_size=3, + activation=activation, + trainable=next(trainable_iter), + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + + self.conv1 = Conv2dLayer( + tmp_channels, + out_channels, + kernel_size=3, + activation=activation, + down=2, + trainable=next(trainable_iter), + resample_filter=resample_filter, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + + if architecture == "resnet": + self.skip = Conv2dLayer( + tmp_channels, + out_channels, + kernel_size=1, + bias=False, + down=2, + trainable=next(trainable_iter), + resample_filter=resample_filter, + channels_last=self.channels_last, + ) + + def forward(self, x, img, force_fp32=False): + # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + dtype = torch.float32 + memory_format = ( + torch.channels_last + if self.channels_last and not force_fp32 + else torch.contiguous_format + ) + + # Input. + if x is not None: + x = x.to(dtype=dtype, memory_format=memory_format) + + # FromRGB. + if self.in_channels == 0: + img = img.to(dtype=dtype, memory_format=memory_format) + y = self.fromrgb(img) + x = x + y if x is not None else y + img = ( + downsample2d(img, self.resample_filter) + if self.architecture == "skip" + else None + ) + + # Main layers. + if self.architecture == "resnet": + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + feat = x.clone() + x = self.conv1(x, gain=np.sqrt(0.5)) + x = y.add_(x) + else: + x = self.conv0(x) + feat = x.clone() + x = self.conv1(x) + + assert x.dtype == dtype + return x, img, feat + + +class EncoderNetwork(torch.nn.Module): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + z_dim, # Input latent (Z) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture="orig", # Architecture: 'orig', 'skip', 'resnet'. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for EncoderEpilogue. + ): + super().__init__() + self.c_dim = c_dim + self.z_dim = z_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [ + 2**i for i in range(self.img_resolution_log2, 2, -1) + ] + channels_dict = { + res: min(channel_base // res, channel_max) + for res in self.block_resolutions + [4] + } + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict( + img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp + ) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = res >= fp16_resolution + use_fp16 = False + block = EncoderBlock( + in_channels, + tmp_channels, + out_channels, + resolution=res, + first_layer_idx=cur_layer_idx, + use_fp16=use_fp16, + **block_kwargs, + **common_kwargs, + ) + setattr(self, f"b{res}", block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork( + z_dim=0, + c_dim=c_dim, + w_dim=cmap_dim, + num_ws=None, + w_avg_beta=None, + **mapping_kwargs, + ) + self.b4 = EncoderEpilogue( + channels_dict[4], + cmap_dim=cmap_dim, + z_dim=z_dim * 2, + resolution=4, + **epilogue_kwargs, + **common_kwargs, + ) + + def forward(self, img, c, **block_kwargs): + x = None + feats = {} + for res in self.block_resolutions: + block = getattr(self, f"b{res}") + x, img, feat = block(x, img, **block_kwargs) + feats[res] = feat + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x, const_e = self.b4(x, cmap) + feats[4] = const_e + + B, _ = x.shape + z = torch.zeros( + (B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device + ) ## Noise for Co-Modulation + return x, z, feats + + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [ + i + for i in range(x.ndim) + if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1) + ] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims + 1 :]) + assert x.shape == shape + return x + + +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise=None, # Optional noise tensor to add to the output activations. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + padding=0, # Padding with respect to the upsampled image. + resample_filter=None, + # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate=True, # Apply weight demodulation? + flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * ( + 1 + / np.sqrt(in_channels * kh * kw) + / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True) + ) # max_Ikk + styles = styles / styles.norm(float("inf"), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample( + x=x, + w=weight.to(x.dtype), + f=resample_filter, + up=up, + down=down, + padding=padding, + flip_weight=flip_weight, + ) + if demodulate and noise is not None: + x = fma( + x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype) + ) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + batch_size = int(batch_size) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample( + x=x, + w=w.to(x.dtype), + f=resample_filter, + up=up, + down=down, + padding=padding, + groups=batch_size, + flip_weight=flip_weight, + ) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + + +class SynthesisLayer(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=True, # Enable noise input? + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last=False, # Use channels_last format for the weights? + ): + super().__init__() + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + ) + if use_noise: + self.register_buffer("noise_const", torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1): + assert noise_mode in ["random", "const", "none"] + in_resolution = self.resolution // self.up + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == "random": + noise = ( + torch.randn( + [x.shape[0], 1, self.resolution, self.resolution], device=x.device + ) + * self.noise_strength + ) + if self.use_noise and noise_mode == "const": + noise = self.noise_const * self.noise_strength + + flip_weight = self.up == 1 # slightly faster + x = modulated_conv2d( + x=x, + weight=self.weight, + styles=styles, + noise=noise, + up=self.up, + padding=self.padding, + resample_filter=self.resample_filter, + flip_weight=flip_weight, + fused_modconv=fused_modconv, + ) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = F.leaky_relu(x, negative_slope=0.2, inplace=False) + if act_gain != 1: + x = x * act_gain + if act_clamp is not None: + x = x.clamp(-act_clamp, act_clamp) + return x + + +class ToRGBLayer(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + w_dim, + kernel_size=1, + conv_clamp=None, + channels_last=False, + ): + super().__init__() + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + ) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + x = modulated_conv2d( + x=x, + weight=self.weight, + styles=styles, + demodulate=False, + fused_modconv=fused_modconv, + ) + x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + +class SynthesisForeword(torch.nn.Module): + def __init__( + self, + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + in_channels, + img_channels, # Number of input color channels. + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + ): + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + self.fc = FullyConnectedLayer( + self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation + ) + self.conv = SynthesisLayer( + self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4 + ) + + if architecture == "skip": + self.torgb = ToRGBLayer( + self.in_channels, + self.img_channels, + kernel_size=1, + w_dim=(z_dim // 2) * 3, + ) + + def forward(self, x, ws, feats, img, force_fp32=False): + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + x_global = x.clone() + # ToRGB. + x = self.fc(x) + x = x.view(-1, self.z_dim // 2, 4, 4) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + x_skip = feats[4].clone() + x = x + x_skip + + mod_vector = [] + mod_vector.append(ws[:, 0]) + mod_vector.append(x_global.clone()) + mod_vector = torch.cat(mod_vector, dim=1) + + x = self.conv(x, mod_vector) + + mod_vector = [] + mod_vector.append(ws[:, 2 * 2 - 3]) + mod_vector.append(x_global.clone()) + mod_vector = torch.cat(mod_vector, dim=1) + + if self.architecture == "skip": + img = self.torgb(x, mod_vector) + img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) + + assert x.dtype == dtype + return x, img + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=False), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res + + +class FourierUnit(nn.Module): + def __init__( + self, + in_channels, + out_channels, + groups=1, + spatial_scale_factor=None, + spatial_scale_mode="bilinear", + spectral_pos_encoding=False, + use_se=False, + se_kwargs=None, + ffc3d=False, + fft_norm="ortho", + ): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d( + in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, + stride=1, + padding=0, + groups=self.groups, + bias=False, + ) + self.relu = torch.nn.ReLU(inplace=False) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate( + x, + scale_factor=self.spatial_scale_factor, + mode=self.spatial_scale_mode, + align_corners=False, + ) + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view( + ( + batch, + -1, + ) + + ffted.size()[3:] + ) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = ( + torch.linspace(0, 1, height)[None, None, :, None] + .expand(batch, 1, height, width) + .to(ffted) + ) + coords_hor = ( + torch.linspace(0, 1, width)[None, None, None, :] + .expand(batch, 1, height, width) + .to(ffted) + ) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(ffted) + + ffted = ( + ffted.view( + ( + batch, + -1, + 2, + ) + + ffted.size()[2:] + ) + .permute(0, 1, 3, 4, 2) + .contiguous() + ) # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn( + ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm + ) + + if self.spatial_scale_factor is not None: + output = F.interpolate( + output, + size=orig_size, + mode=self.spatial_scale_mode, + align_corners=False, + ) + + return output + + +class SpectralTransform(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride=1, + groups=1, + enable_lfu=True, + **fu_kwargs, + ): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False + ), + # nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True), + ) + self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False + ) + + def forward(self, x): + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat( + torch.split(x[:, : c // 4], split_s, dim=-2), dim=1 + ).contiguous() + xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class FFC(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + enable_lfu=True, + padding_type="reflect", + gated=False, + **spectral_kwargs, + ): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, "Stride should be 1 or 2." + self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + # groups_g = 1 if groups == 1 else int(groups * ratio_gout) + # groups_l = 1 if groups == 1 else groups - groups_g + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module( + in_cl, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module( + in_cl, + out_cg, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module( + in_cg, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module( + in_cg, + out_cg, + stride, + 1 if groups == 1 else groups // 2, + enable_lfu, + **spectral_kwargs, + ) + + self.gated = gated + module = ( + nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + ) + self.gate = module(in_channels, 2, 1) + + def forward(self, x, fname=None): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + spec_x = self.convg2g(x_g) + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + spec_x + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + norm_layer=nn.SyncBatchNorm, + activation_layer=nn.Identity, + padding_type="reflect", + enable_lfu=True, + **kwargs, + ): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC( + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride, + padding, + dilation, + groups, + bias, + enable_lfu, + padding_type=padding_type, + **kwargs, + ) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + # self.bn_l = lnorm(out_channels - global_channels) + # self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x, fname=None): + x_l, x_g = self.ffc( + x, + fname=fname, + ) + x_l = self.act_l(x_l) + x_g = self.act_g(x_g) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + def __init__( + self, + dim, + padding_type, + norm_layer, + activation_layer=nn.ReLU, + dilation=1, + spatial_transform_kwargs=None, + inline=False, + ratio_gin=0.75, + ratio_gout=0.75, + ): + super().__init__() + self.conv1 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + self.conv2 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + self.inline = inline + + def forward(self, x, fname=None): + if self.inline: + x_l, x_g = ( + x[:, : -self.conv1.ffc.global_in_num], + x[:, -self.conv1.ffc.global_in_num :], + ) + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1((x_l, x_g), fname=fname) + x_l, x_g = self.conv2((x_l, x_g), fname=fname) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class ConcatTupleLayer(nn.Module): + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCBlock(torch.nn.Module): + def __init__( + self, + dim, # Number of output/input channels. + kernel_size, # Width and height of the convolution kernel. + padding, + ratio_gin=0.75, + ratio_gout=0.75, + activation="linear", # Activation function: 'relu', 'lrelu', etc. + ): + super().__init__() + if activation == "linear": + self.activation = nn.Identity + else: + self.activation = nn.ReLU + self.padding = padding + self.kernel_size = kernel_size + self.ffc_block = FFCResnetBlock( + dim=dim, + padding_type="reflect", + norm_layer=nn.SyncBatchNorm, + activation_layer=self.activation, + dilation=1, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + + self.concat_layer = ConcatTupleLayer() + + def forward(self, gen_ft, mask, fname=None): + x = gen_ft.float() + + x_l, x_g = ( + x[:, : -self.ffc_block.conv1.ffc.global_in_num], + x[:, -self.ffc_block.conv1.ffc.global_in_num :], + ) + id_l, id_g = x_l, x_g + + x_l, x_g = self.ffc_block((x_l, x_g), fname=fname) + x_l, x_g = id_l + x_l, id_g + x_g + x = self.concat_layer((x_l, x_g)) + + return x + gen_ft.float() + + +class FFCSkipLayer(torch.nn.Module): + def __init__( + self, + dim, # Number of input/output channels. + kernel_size=3, # Convolution kernel size. + ratio_gin=0.75, + ratio_gout=0.75, + ): + super().__init__() + self.padding = kernel_size // 2 + + self.ffc_act = FFCBlock( + dim=dim, + kernel_size=kernel_size, + activation=nn.ReLU, + padding=self.padding, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + + def forward(self, gen_ft, mask, fname=None): + x = self.ffc_act(gen_ft, mask, fname=fname) + return x + + +class SynthesisBlock(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ["orig", "skip", "resnet"] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = use_fp16 and fp16_channels_last + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1} + + if in_channels != 0 and resolution >= 8: + self.ffc_skip = nn.ModuleList() + for _ in range(self.res_ffc[resolution]): + self.ffc_skip.append(FFCSkipLayer(dim=out_channels)) + + if in_channels == 0: + self.const = torch.nn.Parameter( + torch.randn([out_channels, resolution, resolution]) + ) + + if in_channels != 0: + self.conv0 = SynthesisLayer( + in_channels, + out_channels, + w_dim=w_dim * 3, + resolution=resolution, + up=2, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + **layer_kwargs, + ) + self.num_conv += 1 + + self.conv1 = SynthesisLayer( + out_channels, + out_channels, + w_dim=w_dim * 3, + resolution=resolution, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + **layer_kwargs, + ) + self.num_conv += 1 + + if is_last or architecture == "skip": + self.torgb = ToRGBLayer( + out_channels, + img_channels, + w_dim=w_dim * 3, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + self.num_torgb += 1 + + if in_channels != 0 and architecture == "resnet": + self.skip = Conv2dLayer( + in_channels, + out_channels, + kernel_size=1, + bias=False, + up=2, + resample_filter=resample_filter, + channels_last=self.channels_last, + ) + + def forward( + self, + x, + mask, + feats, + img, + ws, + fname=None, + force_fp32=False, + fused_modconv=None, + **layer_kwargs, + ): + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + dtype = torch.float32 + memory_format = ( + torch.channels_last + if self.channels_last and not force_fp32 + else torch.contiguous_format + ) + if fused_modconv is None: + fused_modconv = (not self.training) and ( + dtype == torch.float32 or int(x.shape[0]) == 1 + ) + + x = x.to(dtype=dtype, memory_format=memory_format) + x_skip = ( + feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format) + ) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == "resnet": + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0( + x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) + if len(self.ffc_skip) > 0: + mask = F.interpolate( + mask, + size=x_skip.shape[2:], + ) + z = x + x_skip + for fres in self.ffc_skip: + z = fres(z, mask) + x = x + z + else: + x = x + x_skip + x = self.conv1( + x, + ws[1].clone(), + fused_modconv=fused_modconv, + gain=np.sqrt(0.5), + **layer_kwargs, + ) + x = y.add_(x) + else: + x = self.conv0( + x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) + if len(self.ffc_skip) > 0: + mask = F.interpolate( + mask, + size=x_skip.shape[2:], + ) + z = x + x_skip + for fres in self.ffc_skip: + z = fres(z, mask) + x = x + z + else: + x = x + x_skip + x = self.conv1( + x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) + # ToRGB. + if img is not None: + img = upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == "skip": + y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + x = x.to(dtype=dtype) + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + +class SynthesisNetwork(torch.nn.Module): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + z_dim, # Output Latent (Z) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [ + 2**i for i in range(3, self.img_resolution_log2 + 1) + ] + channels_dict = { + res: min(channel_base // res, channel_max) for res in self.block_resolutions + } + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.foreword = SynthesisForeword( + img_channels=img_channels, + in_channels=min(channel_base // 4, channel_max), + z_dim=z_dim * 2, + resolution=4, + ) + + self.num_ws = self.img_resolution_log2 * 2 - 2 + for res in self.block_resolutions: + if res // 2 in channels_dict.keys(): + in_channels = channels_dict[res // 2] if res > 4 else 0 + else: + in_channels = min(channel_base // (res // 2), channel_max) + out_channels = channels_dict[res] + use_fp16 = res >= fp16_resolution + use_fp16 = False + is_last = res == self.img_resolution + block = SynthesisBlock( + in_channels, + out_channels, + w_dim=w_dim, + resolution=res, + img_channels=img_channels, + is_last=is_last, + use_fp16=use_fp16, + **block_kwargs, + ) + setattr(self, f"b{res}", block) + + def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs): + img = None + + x, img = self.foreword(x_global, ws, feats, img) + + for res in self.block_resolutions: + block = getattr(self, f"b{res}") + mod_vector0 = [] + mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5]) + mod_vector0.append(x_global.clone()) + mod_vector0 = torch.cat(mod_vector0, dim=1) + + mod_vector1 = [] + mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4]) + mod_vector1.append(x_global.clone()) + mod_vector1 = torch.cat(mod_vector1, dim=1) + + mod_vector_rgb = [] + mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3]) + mod_vector_rgb.append(x_global.clone()) + mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1) + x, img = block( + x, + mask, + feats, + img, + (mod_vector0, mod_vector1, mod_vector_rgb), + fname=fname, + **block_kwargs, + ) + return img + + +class MappingNetwork(torch.nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = ( + [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + ) + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer( + in_features, + out_features, + activation=activation, + lr_multiplier=lr_multiplier, + ) + setattr(self, f"fc{idx}", layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer("w_avg", torch.zeros([w_dim])) + + def forward( + self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False + ): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function("input"): + if self.z_dim > 0: + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f"fc{idx}") + x = layer(x) + + # Update moving average of W. + if self.w_avg_beta is not None and self.training and not skip_w_avg_update: + with torch.autograd.profiler.record_function("update_w_avg"): + self.w_avg.copy_( + x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta) + ) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function("broadcast"): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function("truncate"): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp( + x[:, :truncation_cutoff], truncation_psi + ) + return x + + +class Generator(torch.nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + encoder_kwargs={}, # Arguments for EncoderNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.encoder = EncoderNetwork( + c_dim=c_dim, + z_dim=z_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **encoder_kwargs, + ) + self.synthesis = SynthesisNetwork( + z_dim=z_dim, + w_dim=w_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **synthesis_kwargs, + ) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork( + z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs + ) + + def forward( + self, + img, + c, + fname=None, + truncation_psi=1, + truncation_cutoff=None, + **synthesis_kwargs, + ): + mask = img[:, -1].unsqueeze(1) + x_global, z, feats = self.encoder(img, c) + ws = self.mapping( + z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff + ) + img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs) + return img + + +FCF_MODEL_URL = os.environ.get( + "FCF_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth", +) +FCF_MODEL_MD5 = os.environ.get("FCF_MODEL_MD5", "3323152bc01bf1c56fd8aba74435a211") + + +class FcF(InpaintModel): + name = "fcf" + min_size = 512 + pad_mod = 512 + pad_to_square = True + is_erase_model = True + + def init_model(self, device, **kwargs): + seed = 0 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + kwargs = { + "channel_base": 1 * 32768, + "channel_max": 512, + "num_fp16_res": 4, + "conv_clamp": 256, + } + G = Generator( + z_dim=512, + c_dim=0, + w_dim=512, + img_resolution=512, + img_channels=3, + synthesis_kwargs=kwargs, + encoder_kwargs=kwargs, + mapping_kwargs={"num_layers": 2}, + ) + self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5) + self.label = torch.zeros([1, self.model.c_dim], device=device) + + @staticmethod + def download(): + download_model(FCF_MODEL_URL, FCF_MODEL_MD5) + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL)) + + @torch.no_grad() + def __call__(self, image, mask, config: InpaintRequest): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + if image.shape[0] == 512 and image.shape[1] == 512: + return self._pad_forward(image, mask, config) + + boxes = boxes_from_mask(mask) + crop_result = [] + config.hd_strategy_crop_margin = 128 + for box in boxes: + crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config) + origin_size = crop_image.shape[:2] + resize_image = resize_max_size(crop_image, size_limit=512) + resize_mask = resize_max_size(crop_mask, size_limit=512) + inpaint_result = self._pad_forward(resize_image, resize_mask, config) + + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + + original_pixel_indices = crop_mask < 127 + inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][ + original_pixel_indices + ] + + crop_result.append((inpaint_result, crop_box)) + + inpaint_result = image[:, :, ::-1].copy() + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + inpaint_result[y1:y2, x1:x2, :] = crop_image + + return inpaint_result + + def forward(self, image, mask, config: InpaintRequest): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] mask area == 255 + return: BGR IMAGE + """ + + image = norm_img(image) # [0, 1] + image = image * 2 - 1 # [0, 1] -> [-1, 1] + mask = (mask > 120) * 255 + mask = norm_img(mask) + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + + erased_img = image * (1 - mask) + input_image = torch.cat([0.5 - mask, erased_img], dim=1) + + output = self.model( + input_image, self.label, truncation_psi=0.1, noise_mode="none" + ) + output = ( + (output.permute(0, 2, 3, 1) * 127.5 + 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + ) + output = output[0].cpu().numpy() + cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/custom-demo/back-end/model/helper/__init__.py b/custom-demo/back-end/model/helper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/helper/controlnet_preprocess.py b/custom-demo/back-end/model/helper/controlnet_preprocess.py new file mode 100644 index 0000000..75c409f --- /dev/null +++ b/custom-demo/back-end/model/helper/controlnet_preprocess.py @@ -0,0 +1,68 @@ +import torch +import PIL +import cv2 +from PIL import Image +import numpy as np + +from iopaint.helper import pad_img_to_modulo + + +def make_canny_control_image(image: np.ndarray) -> Image: + canny_image = cv2.Canny(image, 100, 200) + canny_image = canny_image[:, :, None] + canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) + canny_image = PIL.Image.fromarray(canny_image) + control_image = canny_image + return control_image + + +def make_openpose_control_image(image: np.ndarray) -> Image: + from controlnet_aux import OpenposeDetector + + processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") + control_image = processor(image, hand_and_face=True) + return control_image + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize( + input_image, + (W, H), + interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA, + ) + return img + + +def make_depth_control_image(image: np.ndarray) -> Image: + from controlnet_aux import MidasDetector + + midas = MidasDetector.from_pretrained("lllyasviel/Annotators") + + origin_height, origin_width = image.shape[:2] + pad_image = pad_img_to_modulo(image, mod=64, square=False, min_size=512) + depth_image = midas(pad_image) + depth_image = depth_image[0:origin_height, 0:origin_width] + depth_image = depth_image[:, :, None] + depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2) + control_image = PIL.Image.fromarray(depth_image) + return control_image + + +def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor: + """ + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + """ + image = image.astype(np.float32) / 255.0 + image[mask[:, :, -1] > 128] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image diff --git a/custom-demo/back-end/model/helper/cpu_text_encoder.py b/custom-demo/back-end/model/helper/cpu_text_encoder.py new file mode 100644 index 0000000..116eb48 --- /dev/null +++ b/custom-demo/back-end/model/helper/cpu_text_encoder.py @@ -0,0 +1,41 @@ +import torch +from transformers import PreTrainedModel + +from ..utils import torch_gc + + +class CPUTextEncoderWrapper(PreTrainedModel): + def __init__(self, text_encoder, torch_dtype): + super().__init__(text_encoder.config) + self.config = text_encoder.config + self._device = text_encoder.device + # cpu not support float16 + self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) + self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) + self.torch_dtype = torch_dtype + del text_encoder + torch_gc() + + def __call__(self, x, **kwargs): + input_device = x.device + original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs) + for k, v in original_output.items(): + if isinstance(v, tuple): + original_output[k] = [ + v[i].to(input_device).to(self.torch_dtype) for i in range(len(v)) + ] + else: + original_output[k] = v.to(input_device).to(self.torch_dtype) + return original_output + + @property + def dtype(self): + return self.torch_dtype + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return self._device \ No newline at end of file diff --git a/custom-demo/back-end/model/helper/g_diffuser_bot.py b/custom-demo/back-end/model/helper/g_diffuser_bot.py new file mode 100644 index 0000000..f669f9a --- /dev/null +++ b/custom-demo/back-end/model/helper/g_diffuser_bot.py @@ -0,0 +1,207 @@ +# code copy from: https://github.com/parlance-zz/g-diffuser-bot +import cv2 +import numpy as np + + +def np_img_grey_to_rgb(data): + if data.ndim == 3: + return data + return np.expand_dims(data, 2) * np.ones((1, 1, 3)) + + +def convolve(data1, data2): # fast convolution with fft + if data1.ndim != data2.ndim: # promote to rgb if mismatch + if data1.ndim < 3: + data1 = np_img_grey_to_rgb(data1) + if data2.ndim < 3: + data2 = np_img_grey_to_rgb(data2) + return ifft2(fft2(data1) * fft2(data2)) + + +def fft2(data): + if data.ndim > 2: # multiple channels + out_fft = np.zeros( + (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128 + ) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho") + out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c]) + else: # single channel + out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho") + out_fft[:, :] = np.fft.ifftshift(out_fft[:, :]) + + return out_fft + + +def ifft2(data): + if data.ndim > 2: # multiple channels + out_ifft = np.zeros( + (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128 + ) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho") + out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c]) + else: # single channel + out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho") + out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :]) + + return out_ifft + + +def get_gradient_kernel(width, height, std=3.14, mode="linear"): + window_scale_x = float( + width / min(width, height) + ) # for non-square aspect ratios we still want a circular kernel + window_scale_y = float(height / min(width, height)) + if mode == "gaussian": + x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x + kx = np.exp(-x * x * std) + if window_scale_x != window_scale_y: + y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y + ky = np.exp(-y * y * std) + else: + y = x + ky = kx + return np.outer(kx, ky) + elif mode == "linear": + x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x + if window_scale_x != window_scale_y: + y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y + else: + y = x + return np.clip(1.0 - np.sqrt(np.add.outer(x * x, y * y)) * std / 3.14, 0.0, 1.0) + else: + raise Exception("Error: Unknown mode in get_gradient_kernel: {0}".format(mode)) + + +def image_blur(data, std=3.14, mode="linear"): + width = data.shape[0] + height = data.shape[1] + kernel = get_gradient_kernel(width, height, std, mode=mode) + return np.real(convolve(data, kernel / np.sqrt(np.sum(kernel * kernel)))) + + +def soften_mask(mask_img, softness, space): + if softness == 0: + return mask_img + softness = min(softness, 1.0) + space = np.clip(space, 0.0, 1.0) + original_max_opacity = np.max(mask_img) + out_mask = mask_img <= 0.0 + blurred_mask = image_blur(mask_img, 3.5 / softness, mode="linear") + blurred_mask = np.maximum(blurred_mask - np.max(blurred_mask[out_mask]), 0.0) + mask_img *= blurred_mask # preserve partial opacity in original input mask + mask_img /= np.max(mask_img) # renormalize + mask_img = np.clip(mask_img - space, 0.0, 1.0) # make space + mask_img /= np.max(mask_img) # and renormalize again + mask_img *= original_max_opacity # restore original max opacity + return mask_img + + +def expand_image( + cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float +): + assert cv2_img.shape[2] == 3 + origin_h, origin_w = cv2_img.shape[:2] + new_width = cv2_img.shape[1] + left + right + new_height = cv2_img.shape[0] + top + bottom + + # TODO: which is better? + # new_img = np.random.randint(0, 255, (new_height, new_width, 3), np.uint8) + new_img = cv2.copyMakeBorder( + cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE + ) + mask_img = np.zeros((new_height, new_width), np.uint8) + mask_img[top: top + cv2_img.shape[0], left: left + cv2_img.shape[1]] = 255 + + if softness > 0.0: + mask_img = soften_mask(mask_img / 255.0, softness / 100.0, space / 100.0) + mask_img = (np.clip(mask_img, 0.0, 1.0) * 255.0).astype(np.uint8) + + mask_image = 255.0 - mask_img # extract mask from alpha channel and invert + rgb_init_image = ( + 0.0 + new_img[:, :, 0:3] + ) # strip mask from init_img leaving only rgb channels + + hard_mask = np.zeros_like(cv2_img[:, :, 0]) + if top != 0: + hard_mask[0: origin_h // 2, :] = 255 + if bottom != 0: + hard_mask[origin_h // 2:, :] = 255 + if left != 0: + hard_mask[:, 0: origin_w // 2] = 255 + if right != 0: + hard_mask[:, origin_w // 2:] = 255 + + hard_mask = cv2.copyMakeBorder( + hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255 + ) + mask_image = np.where(hard_mask > 0, mask_image, 0) + return rgb_init_image.astype(np.uint8), mask_image.astype(np.uint8) + + +def expand_image2( + cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float +): + assert cv2_img.shape[2] == 3 + origin_h, origin_w = cv2_img.shape[:2] + new_width = cv2_img.shape[1] + left + right + new_height = cv2_img.shape[0] + top + bottom + + # TODO: which is better? + # new_img = np.random.randint(0, 255, (new_height, new_width, 3), np.uint8) + new_img = cv2.copyMakeBorder( + cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE + ) + + inner_padding_left = 13 if left > 0 else 0 + inner_padding_right = 13 if right > 0 else 0 + inner_padding_top = 13 if top > 0 else 0 + inner_padding_bottom = 13 if bottom > 0 else 0 + + mask_image = np.zeros( + ( + origin_h - inner_padding_top - inner_padding_bottom + , origin_w - inner_padding_left - inner_padding_right + ), + np.uint8) + mask_image = cv2.copyMakeBorder( + mask_image, + top + inner_padding_top, + bottom + inner_padding_bottom, + left + inner_padding_left, + right + inner_padding_right, + cv2.BORDER_CONSTANT, + value=255 + ) + # k = 2*int(min(origin_h, origin_w) // 6)+1 + k = 7 + mask_image = cv2.GaussianBlur(mask_image, (k, k), 0) + return new_img, mask_image + + +if __name__ == "__main__": + from pathlib import Path + + current_dir = Path(__file__).parent.absolute().resolve() + image_path = "/Users/cwq/code/github/IOPaint/iopaint/tests/bunny.jpeg" + init_image = cv2.imread(str(image_path)) + init_image, mask_image = expand_image2( + init_image, + top=0, + right=0, + bottom=0, + left=100, + softness=20, + space=20, + ) + print(mask_image.dtype, mask_image.min(), mask_image.max()) + print(init_image.dtype, init_image.min(), init_image.max()) + mask_image = mask_image.astype(np.uint8) + init_image = init_image.astype(np.uint8) + cv2.imwrite("expanded_image.png", init_image) + cv2.imwrite("expanded_mask.png", mask_image) diff --git a/custom-demo/back-end/model/instruct_pix2pix.py b/custom-demo/back-end/model/instruct_pix2pix.py new file mode 100644 index 0000000..fc8cd26 --- /dev/null +++ b/custom-demo/back-end/model/instruct_pix2pix.py @@ -0,0 +1,64 @@ +import PIL.Image +import cv2 +import torch +from loguru import logger + +from iopaint.const import INSTRUCT_PIX2PIX_NAME +from .base import DiffusionInpaintModel +from iopaint.schema import InpaintRequest +from .utils import get_torch_dtype, enable_low_mem, is_local_files_only + + +class InstructPix2Pix(DiffusionInpaintModel): + name = INSTRUCT_PIX2PIX_NAME + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from diffusers import StableDiffusionInstructPix2PixPipeline + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + + model_kwargs = {"local_files_only": is_local_files_only(**kwargs)} + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained( + self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs + ) + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0] + """ + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + num_inference_steps=config.sd_steps, + image_guidance_scale=config.p2p_image_guidance_scale, + guidance_scale=config.sd_guidance_scale, + output_type="np", + generator=torch.manual_seed(config.sd_seed), + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output diff --git a/custom-demo/back-end/model/kandinsky.py b/custom-demo/back-end/model/kandinsky.py new file mode 100644 index 0000000..1a0bf1c --- /dev/null +++ b/custom-demo/back-end/model/kandinsky.py @@ -0,0 +1,65 @@ +import PIL.Image +import cv2 +import numpy as np +import torch + +from iopaint.const import KANDINSKY22_NAME +from .base import DiffusionInpaintModel +from iopaint.schema import InpaintRequest +from .utils import get_torch_dtype, enable_low_mem, is_local_files_only + + +class Kandinsky(DiffusionInpaintModel): + pad_mod = 64 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from diffusers import AutoPipelineForInpainting + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + + model_kwargs = { + "torch_dtype": torch_dtype, + "local_files_only": is_local_files_only(**kwargs), + } + self.model = AutoPipelineForInpainting.from_pretrained( + self.name, **model_kwargs + ).to(device) + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + self.set_scheduler(config) + + generator = torch.manual_seed(config.sd_seed) + mask = mask.astype(np.float32) / 255 + img_h, img_w = image.shape[:2] + + # kandinsky 没有 strength + output = self.model( + prompt=config.prompt, + negative_prompt=config.negative_prompt, + image=PIL.Image.fromarray(image), + mask_image=mask[:, :, 0], + height=img_h, + width=img_w, + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback_on_step_end=self.callback, + generator=generator, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +class Kandinsky22(Kandinsky): + name = KANDINSKY22_NAME diff --git a/custom-demo/back-end/model/lama.py b/custom-demo/back-end/model/lama.py new file mode 100644 index 0000000..7aba242 --- /dev/null +++ b/custom-demo/back-end/model/lama.py @@ -0,0 +1,57 @@ +import os + +import cv2 +import numpy as np +import torch + +from iopaint.helper import ( + norm_img, + get_cache_path_by_url, + load_jit_model, + download_model, +) +from iopaint.schema import InpaintRequest +from .base import InpaintModel + +LAMA_MODEL_URL = os.environ.get( + "LAMA_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", +) +LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500") + + +class LaMa(InpaintModel): + name = "lama" + pad_mod = 8 + is_erase_model = True + + @staticmethod + def download(): + download_model(LAMA_MODEL_URL, LAMA_MODEL_MD5) + + def init_model(self, device, **kwargs): + self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval() + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL)) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W] + return: BGR IMAGE + """ + image = norm_img(image) + mask = norm_img(mask) + + mask = (mask > 0) * 1 + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + + inpainted_image = self.model(image, mask) + + cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() + cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/custom-demo/back-end/model/ldm.py b/custom-demo/back-end/model/ldm.py new file mode 100644 index 0000000..19e51a3 --- /dev/null +++ b/custom-demo/back-end/model/ldm.py @@ -0,0 +1,336 @@ +import os + +import numpy as np +import torch +from loguru import logger + +from .base import InpaintModel +from .ddim_sampler import DDIMSampler +from .plms_sampler import PLMSSampler +from iopaint.schema import InpaintRequest, LDMSampler + +torch.manual_seed(42) +import torch.nn as nn +from iopaint.helper import ( + download_model, + norm_img, + get_cache_path_by_url, + load_jit_model, +) +from .utils import ( + make_beta_schedule, + timestep_embedding, +) + +LDM_ENCODE_MODEL_URL = os.environ.get( + "LDM_ENCODE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt", +) +LDM_ENCODE_MODEL_MD5 = os.environ.get( + "LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296" +) + +LDM_DECODE_MODEL_URL = os.environ.get( + "LDM_DECODE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt", +) +LDM_DECODE_MODEL_MD5 = os.environ.get( + "LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c" +) + +LDM_DIFFUSION_MODEL_URL = os.environ.get( + "LDM_DIFFUSION_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt", +) + +LDM_DIFFUSION_MODEL_MD5 = os.environ.get( + "LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d" +) + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + device, + timesteps=1000, + beta_schedule="linear", + linear_start=0.0015, + linear_end=0.0205, + cosine_s=0.008, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + parameterization="eps", # all assuming fixed variance schedules + use_positional_encodings=False, + ): + super().__init__() + self.device = device + self.parameterization = parameterization + self.use_positional_encodings = use_positional_encodings + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + self.register_schedule( + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + self.device, + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + +class LatentDiffusion(DDPM): + def __init__( + self, + diffusion_model, + device, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + scale_factor=1.0, + scale_by_std=False, + *args, + **kwargs, + ): + self.num_timesteps_cond = 1 + self.scale_by_std = scale_by_std + super().__init__(device, *args, **kwargs) + self.diffusion_model = diffusion_model + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.num_downs = 2 + self.scale_factor = scale_factor + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def apply_model(self, x_noisy, t, cond): + # x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128 + t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False) + x_recon = self.diffusion_model(x_noisy, t_emb, cond) + return x_recon + + +class LDM(InpaintModel): + name = "ldm" + pad_mod = 32 + is_erase_model = True + + def __init__(self, device, fp16: bool = True, **kwargs): + self.fp16 = fp16 + super().__init__(device) + self.device = device + + def init_model(self, device, **kwargs): + self.diffusion_model = load_jit_model( + LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5 + ) + self.cond_stage_model_decode = load_jit_model( + LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5 + ) + self.cond_stage_model_encode = load_jit_model( + LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5 + ) + if self.fp16 and "cuda" in str(device): + self.diffusion_model = self.diffusion_model.half() + self.cond_stage_model_decode = self.cond_stage_model_decode.half() + self.cond_stage_model_encode = self.cond_stage_model_encode.half() + + self.model = LatentDiffusion(self.diffusion_model, device) + + @staticmethod + def download(): + download_model(LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5) + download_model(LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5) + download_model(LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL), + get_cache_path_by_url(LDM_DECODE_MODEL_URL), + get_cache_path_by_url(LDM_ENCODE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + @torch.cuda.amp.autocast() + def forward(self, image, mask, config: InpaintRequest): + """ + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + # image [1,3,512,512] float32 + # mask: [1,1,512,512] float32 + # masked_image: [1,3,512,512] float32 + if config.ldm_sampler == LDMSampler.ddim: + sampler = DDIMSampler(self.model) + elif config.ldm_sampler == LDMSampler.plms: + sampler = PLMSSampler(self.model) + else: + raise ValueError() + + steps = config.ldm_steps + image = norm_img(image) + mask = norm_img(mask) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + masked_image = (1 - mask) * image + + mask = self._norm(mask) + masked_image = self._norm(masked_image) + + c = self.cond_stage_model_encode(masked_image) + torch.cuda.empty_cache() + + cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128 + c = torch.cat((c, cc), dim=1) # 1,4,128,128 + + shape = (c.shape[1] - 1,) + c.shape[2:] + samples_ddim = sampler.sample( + steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape + ) + torch.cuda.empty_cache() + x_samples_ddim = self.cond_stage_model_decode( + samples_ddim + ) # samples_ddim: 1, 3, 128, 128 float32 + torch.cuda.empty_cache() + + # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0) + inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + # inpainted = (1 - mask) * image + mask * predicted_image + inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 + inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1] + return inpainted_image + + def _norm(self, tensor): + return tensor * 2.0 - 1.0 diff --git a/custom-demo/back-end/model/manga.py b/custom-demo/back-end/model/manga.py new file mode 100644 index 0000000..1f58251 --- /dev/null +++ b/custom-demo/back-end/model/manga.py @@ -0,0 +1,97 @@ +import os +import random + +import cv2 +import numpy as np +import torch +import time +from loguru import logger + +from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model +from .base import InpaintModel +from iopaint.schema import InpaintRequest + + +MANGA_INPAINTOR_MODEL_URL = os.environ.get( + "MANGA_INPAINTOR_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit", +) +MANGA_INPAINTOR_MODEL_MD5 = os.environ.get( + "MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c" +) + +MANGA_LINE_MODEL_URL = os.environ.get( + "MANGA_LINE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/erika.jit", +) +MANGA_LINE_MODEL_MD5 = os.environ.get( + "MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644" +) + + +class Manga(InpaintModel): + name = "manga" + pad_mod = 16 + is_erase_model = True + + def init_model(self, device, **kwargs): + self.inpaintor_model = load_jit_model( + MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5 + ) + self.line_model = load_jit_model( + MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5 + ) + self.seed = 42 + + @staticmethod + def download(): + download_model(MANGA_INPAINTOR_MODEL_URL, MANGA_INPAINTOR_MODEL_MD5) + download_model(MANGA_LINE_MODEL_URL, MANGA_LINE_MODEL_MD5) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL), + get_cache_path_by_url(MANGA_LINE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + def forward(self, image, mask, config: InpaintRequest): + """ + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + seed = self.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + gray_img = torch.from_numpy( + gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32) + ).to(self.device) + start = time.time() + lines = self.line_model(gray_img) + torch.cuda.empty_cache() + lines = torch.clamp(lines, 0, 255) + logger.info(f"erika_model time: {time.time() - start}") + + mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device) + mask = mask.permute(0, 3, 1, 2) + mask = torch.where(mask > 0.5, 1.0, 0.0) + noise = torch.randn_like(mask) + ones = torch.ones_like(mask) + + gray_img = gray_img / 255 * 2 - 1.0 + lines = lines / 255 * 2 - 1.0 + + start = time.time() + inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones) + logger.info(f"image_inpaintor_model time: {time.time() - start}") + + cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() + cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8) + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR) + return cur_res diff --git a/custom-demo/back-end/model/mat.py b/custom-demo/back-end/model/mat.py new file mode 100644 index 0000000..0c5360f --- /dev/null +++ b/custom-demo/back-end/model/mat.py @@ -0,0 +1,1945 @@ +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from iopaint.helper import ( + load_model, + get_cache_path_by_url, + norm_img, + download_model, +) +from iopaint.schema import InpaintRequest +from .base import InpaintModel +from .utils import ( + setup_filter, + Conv2dLayer, + FullyConnectedLayer, + conv2d_resample, + bias_act, + upsample2d, + activation_funcs, + MinibatchStdLayer, + to_2tuple, + normalize_2nd_moment, + set_seed, +) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + style_dim, # dimension of the style code + demodulate=True, # perfrom demodulation + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + ): + super().__init__() + self.demodulate = demodulate + + self.weight = torch.nn.Parameter( + torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]) + ) + self.out_channels = out_channels + self.kernel_size = kernel_size + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) + self.padding = self.kernel_size // 2 + self.up = up + self.down = down + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + + self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1) + + def forward(self, x, style): + batch, in_channels, height, width = x.shape + style = self.affine(style).view(batch, 1, in_channels, 1, 1) + weight = self.weight * self.weight_gain * style + + if self.demodulate: + decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt() + weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1) + + weight = weight.view( + batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size + ) + x = x.view(1, batch * in_channels, height, width) + x = conv2d_resample( + x=x, + w=weight, + f=self.resample_filter, + up=self.up, + down=self.down, + padding=self.padding, + groups=batch, + ) + out = x.view(batch, self.out_channels, *x.shape[2:]) + + return out + + +class StyleConv(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + style_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=False, # Enable noise input? + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + demodulate=True, # perform demodulation + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + up=up, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + ) + + self.use_noise = use_noise + self.resolution = resolution + if use_noise: + self.register_buffer("noise_const", torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.activation = activation + self.act_gain = activation_funcs[activation].def_gain + self.conv_clamp = conv_clamp + + def forward(self, x, style, noise_mode="random", gain=1): + x = self.conv(x, style) + + assert noise_mode in ["random", "const", "none"] + + if self.use_noise: + if noise_mode == "random": + xh, xw = x.size()[-2:] + noise = ( + torch.randn([x.shape[0], 1, xh, xw], device=x.device) + * self.noise_strength + ) + if noise_mode == "const": + noise = self.noise_const * self.noise_strength + x = x + noise + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + out = bias_act( + x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp + ) + + return out + + +class ToRGB(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + style_dim, + kernel_size=1, + resample_filter=[1, 3, 3, 1], + conv_clamp=None, + demodulate=False, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + ) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + + def forward(self, x, style, skip=None): + x = self.conv(x, style) + out = bias_act(x, self.bias, clamp=self.conv_clamp) + + if skip is not None: + if skip.shape != out.shape: + skip = upsample2d(skip, self.resample_filter) + out = out + skip + + return out + + +def get_style_code(a, b): + return torch.cat([a, b], dim=1) + + +class DecBlockFirst(nn.Module): + def __init__( + self, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.fc = FullyConnectedLayer( + in_features=in_channels * 2, + out_features=in_channels * 4**2, + activation=activation, + ) + self.conv = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = x + E_features[2] + style = get_style_code(ws[:, 0], gs) + x = self.conv(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlockFirstV2(nn.Module): + def __init__( + self, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + # x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = self.conv0(x) + x = x + E_features[2] + style = get_style_code(ws[:, 0], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlock(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): # res = 2, ..., resolution_log2 + super().__init__() + self.res = res + + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, ws, gs, E_features, noise_mode="random"): + style = get_style_code(ws[:, self.res * 2 - 5], gs) + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + E_features[self.res] + style = get_style_code(ws[:, self.res * 2 - 4], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, self.res * 2 - 3], gs) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class MappingNet(torch.nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. + torch_dtype=torch.float32, + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + self.torch_dtype = torch_dtype + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = ( + [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + ) + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer( + in_features, + out_features, + activation=activation, + lr_multiplier=lr_multiplier, + ) + setattr(self, f"fc{idx}", layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer("w_avg", torch.zeros([w_dim])) + + def forward( + self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False + ): + # Embed, normalize, and concat inputs. + x = None + if self.z_dim > 0: + x = normalize_2nd_moment(z) + if self.c_dim > 0: + y = normalize_2nd_moment(self.embed(c)) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f"fc{idx}") + x = layer(x) + + # Update moving average of W. + if self.w_avg_beta is not None and self.training and not skip_w_avg_update: + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp( + x[:, :truncation_cutoff], truncation_psi + ) + + return x + + +class DisFromRGB(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 + super().__init__() + self.conv = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) + + def forward(self, x): + return self.conv(x) + + +class DisBlock(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 + super().__init__() + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + down=2, + activation=activation, + ) + self.skip = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + down=2, + bias=False, + ) + + def forward(self, x): + skip = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + x = self.conv1(x, gain=np.sqrt(0.5)) + out = skip + x + + return out + + +class Discriminator(torch.nn.Module): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + activation="lrelu", + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 + self.resolution_log2 = resolution_log2 + + def nf(stage): + return np.clip( + int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max + ) + + if cmap_dim == None: + cmap_dim = nf(2) + if c_dim == 0: + cmap_dim = 0 + self.cmap_dim = cmap_dim + + if c_dim > 0: + self.mapping = MappingNet( + z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None + ) + + Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] + for res in range(resolution_log2, 2, -1): + Dis.append(DisBlock(nf(res), nf(res - 1), activation)) + + if mbstd_num_channels > 0: + Dis.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis.append( + Conv2dLayer( + nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation + ) + ) + self.Dis = nn.Sequential(*Dis) + + self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation) + self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) + + def forward(self, images_in, masks_in, c): + x = torch.cat([masks_in - 0.5, images_in], dim=1) + x = self.Dis(x) + x = self.fc1(self.fc0(x.flatten(start_dim=1))) + + if self.c_dim > 0: + cmap = self.mapping(None, c) + + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + return x + + +def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): + NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} + return NF[2**stage] + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = FullyConnectedLayer( + in_features=in_features, out_features=hidden_features, activation="lrelu" + ) + self.fc2 = FullyConnectedLayer( + in_features=hidden_features, out_features=out_features + ) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + # B = windows.shape[0] / (H * W / window_size / window_size) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class Conv2dLayerPartial(nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + trainable=True, # Update the weights of this layer during training? + ): + super().__init__() + self.conv = Conv2dLayer( + in_channels, + out_channels, + kernel_size, + bias, + activation, + up, + down, + resample_filter, + conv_clamp, + trainable, + ) + + self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size) + self.slide_winsize = kernel_size**2 + self.stride = down + self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0 + + def forward(self, x, mask=None): + if mask is not None: + with torch.no_grad(): + if self.weight_maskUpdater.type() != x.type(): + self.weight_maskUpdater = self.weight_maskUpdater.to(x) + update_mask = F.conv2d( + mask, + self.weight_maskUpdater, + bias=None, + stride=self.stride, + padding=self.padding, + ) + mask_ratio = self.slide_winsize / (update_mask.to(torch.float32) + 1e-8) + update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1 + mask_ratio = torch.mul(mask_ratio, update_mask).to(x.dtype) + x = self.conv(x) + x = torch.mul(x, mask_ratio) + return x, update_mask + else: + x = self.conv(x) + return x, None + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + down_ratio=1, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = FullyConnectedLayer(in_features=dim, out_features=dim) + self.k = FullyConnectedLayer(in_features=dim, out_features=dim) + self.v = FullyConnectedLayer(in_features=dim, out_features=dim) + self.proj = FullyConnectedLayer(in_features=dim, out_features=dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask_windows=None, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + norm_x = F.normalize(x, p=2.0, dim=-1, eps=torch.finfo(x.dtype).eps) + q = ( + self.q(norm_x) + .reshape(B_, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.k(norm_x) + .view(B_, -1, self.num_heads, C // self.num_heads) + .permute(0, 2, 3, 1) + ) + v = ( + self.v(x) + .view(B_, -1, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k) * self.scale + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + if mask_windows is not None: + attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1) + attn = attn + attn_mask_windows.masked_fill( + attn_mask_windows == 0, float(-100.0) + ).masked_fill(attn_mask_windows == 1, float(0.0)) + with torch.no_grad(): + mask_windows = torch.clamp( + torch.sum(mask_windows, dim=1, keepdim=True), 0, 1 + ).repeat(1, N, 1) + + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + return x, mask_windows + + +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + down_ratio=1, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + if self.shift_size > 0: + down_ratio = 1 + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + down_ratio=down_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.fuse = FullyConnectedLayer( + in_features=dim * 2, out_features=dim, activation="lrelu" + ) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + return attn_mask + + def forward(self, x, x_size, mask=None): + # H, W = self.input_resolution + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + if mask is not None: + mask = mask.view(B, H, W, 1) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + if mask is not None: + shifted_mask = torch.roll( + mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + else: + shifted_x = x + if mask is not None: + shifted_mask = mask + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + if mask is not None: + mask_windows = window_partition(shifted_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1) + else: + mask_windows = None + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows, mask_windows = self.attn( + x_windows, mask_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + else: + attn_windows, mask_windows = self.attn( + x_windows, + mask_windows, + mask=self.calculate_mask(x_size).to(x.dtype).to(x.device), + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + if mask is not None: + mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1) + shifted_mask = window_reverse(mask_windows, self.window_size, H, W) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + if mask is not None: + mask = torch.roll( + shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + if mask is not None: + mask = shifted_mask + x = x.view(B, H * W, C) + if mask is not None: + mask = mask.view(B, H * W, 1) + + # FFN + x = self.fuse(torch.cat([shortcut, x], dim=-1)) + x = self.mlp(x) + + return x, mask + + +class PatchMerging(nn.Module): + def __init__(self, in_channels, out_channels, down=2): + super().__init__() + self.conv = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation="lrelu", + down=down, + ) + self.down = down + + def forward(self, x, x_size, mask=None): + x = token2feature(x, x_size) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(x, mask) + if self.down != 1: + ratio = 1 / self.down + x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio)) + x = feature2token(x) + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class PatchUpsampling(nn.Module): + def __init__(self, in_channels, out_channels, up=2): + super().__init__() + self.conv = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation="lrelu", + up=up, + ) + self.up = up + + def forward(self, x, x_size, mask=None): + x = token2feature(x, x_size) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(x, mask) + if self.up != 1: + x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up)) + x = feature2token(x) + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + down_ratio=1, + mlp_ratio=2.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # patch merging layer + if downsample is not None: + # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample + else: + self.downsample = None + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + down_ratio=down_ratio, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + self.conv = Conv2dLayerPartial( + in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu" + ) + + def forward(self, x, x_size, mask=None): + if self.downsample is not None: + x, x_size, mask = self.downsample(x, x_size, mask) + identity = x + for blk in self.blocks: + if self.use_checkpoint: + x, mask = checkpoint.checkpoint(blk, x, x_size, mask) + else: + x, mask = blk(x, x_size, mask) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(token2feature(x, x_size), mask) + x = feature2token(x) + identity + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class ToToken(nn.Module): + def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1): + super().__init__() + + self.proj = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=dim, + kernel_size=kernel_size, + activation="lrelu", + ) + + def forward(self, x, mask): + x, mask = self.proj(x, mask) + + return x, mask + + +class EncFromRGB(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 + super().__init__() + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) + self.conv1 = Conv2dLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) + + def forward(self, x): + x = self.conv0(x) + x = self.conv1(x) + + return x + + +class ConvBlockDown(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log + super().__init__() + + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + down=2, + ) + self.conv1 = Conv2dLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) + + def forward(self, x): + x = self.conv0(x) + x = self.conv1(x) + + return x + + +def token2feature(x, x_size): + B, N, C = x.shape + h, w = x_size + x = x.permute(0, 2, 1).reshape(B, C, h, w) + return x + + +def feature2token(x): + B, C, H, W = x.shape + x = x.view(B, C, -1).transpose(1, 2) + return x + + +class Encoder(nn.Module): + def __init__( + self, + res_log2, + img_channels, + activation, + patch_size=5, + channels=16, + drop_path_rate=0.1, + ): + super().__init__() + + self.resolution = [] + + for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16 + res = 2**i + self.resolution.append(res) + if i == res_log2: + block = EncFromRGB(img_channels * 2 + 1, nf(i), activation) + else: + block = ConvBlockDown(nf(i + 1), nf(i), activation) + setattr(self, "EncConv_Block_%dx%d" % (res, res), block) + + def forward(self, x): + out = {} + for res in self.resolution: + res_log2 = int(np.log2(res)) + x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x) + out[res_log2] = x + + return out + + +class ToStyle(nn.Module): + def __init__(self, in_channels, out_channels, activation, drop_rate): + super().__init__() + self.conv = nn.Sequential( + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + ) + + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = FullyConnectedLayer( + in_features=in_channels, out_features=out_channels, activation=activation + ) + # self.dropout = nn.Dropout(drop_rate) + + def forward(self, x): + x = self.conv(x) + x = self.pool(x) + x = self.fc(x.flatten(start_dim=1)) + # x = self.dropout(x) + + return x + + +class DecBlockFirstV2(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.res = res + + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + # x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = self.conv0(x) + x = x + E_features[self.res] + style = get_style_code(ws[:, 0], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlock(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): # res = 4, ..., resolution_log2 + super().__init__() + self.res = res + + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, ws, gs, E_features, noise_mode="random"): + style = get_style_code(ws[:, self.res * 2 - 9], gs) + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + E_features[self.res] + style = get_style_code(ws[:, self.res * 2 - 8], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, self.res * 2 - 7], gs) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class Decoder(nn.Module): + def __init__( + self, res_log2, activation, style_dim, use_noise, demodulate, img_channels + ): + super().__init__() + self.Dec_16x16 = DecBlockFirstV2( + 4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels + ) + for res in range(5, res_log2 + 1): + setattr( + self, + "Dec_%dx%d" % (2**res, 2**res), + DecBlock( + res, + nf(res - 1), + nf(res), + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ), + ) + self.res_log2 = res_log2 + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode) + for res in range(5, self.res_log2 + 1): + block = getattr(self, "Dec_%dx%d" % (2**res, 2**res)) + x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode) + + return img + + +class DecStyleBlock(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.res = res + + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, style, skip, noise_mode="random"): + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + skip + x = self.conv1(x, style, noise_mode=noise_mode) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class FirstStage(nn.Module): + def __init__( + self, + img_channels, + img_resolution=256, + dim=180, + w_dim=512, + use_noise=False, + demodulate=True, + activation="lrelu", + ): + super().__init__() + res = 64 + + self.conv_first = Conv2dLayerPartial( + in_channels=img_channels + 1, + out_channels=dim, + kernel_size=3, + activation=activation, + ) + self.enc_conv = nn.ModuleList() + down_time = int(np.log2(img_resolution // res)) + # 根据图片尺寸构建 swim transformer 的层数 + for i in range(down_time): # from input size to 64 + self.enc_conv.append( + Conv2dLayerPartial( + in_channels=dim, + out_channels=dim, + kernel_size=3, + down=2, + activation=activation, + ) + ) + + # from 64 -> 16 -> 64 + depths = [2, 3, 4, 3, 2] + ratios = [1, 1 / 2, 1 / 2, 2, 2] + num_heads = 6 + window_sizes = [8, 16, 16, 16, 8] + drop_path_rate = 0.1 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.tran = nn.ModuleList() + for i, depth in enumerate(depths): + res = int(res * ratios[i]) + if ratios[i] < 1: + merge = PatchMerging(dim, dim, down=int(1 / ratios[i])) + elif ratios[i] > 1: + merge = PatchUpsampling(dim, dim, up=ratios[i]) + else: + merge = None + self.tran.append( + BasicLayer( + dim=dim, + input_resolution=[res, res], + depth=depth, + num_heads=num_heads, + window_size=window_sizes[i], + drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], + downsample=merge, + ) + ) + + # global style + down_conv = [] + for i in range(int(np.log2(16))): + down_conv.append( + Conv2dLayer( + in_channels=dim, + out_channels=dim, + kernel_size=3, + down=2, + activation=activation, + ) + ) + down_conv.append(nn.AdaptiveAvgPool2d((1, 1))) + self.down_conv = nn.Sequential(*down_conv) + self.to_style = FullyConnectedLayer( + in_features=dim, out_features=dim * 2, activation=activation + ) + self.ws_style = FullyConnectedLayer( + in_features=w_dim, out_features=dim, activation=activation + ) + self.to_square = FullyConnectedLayer( + in_features=dim, out_features=16 * 16, activation=activation + ) + + style_dim = dim * 3 + self.dec_conv = nn.ModuleList() + for i in range(down_time): # from 64 to input size + res = res * 2 + self.dec_conv.append( + DecStyleBlock( + res, + dim, + dim, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ) + ) + + def forward(self, images_in, masks_in, ws, noise_mode="random"): + x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1) + + skips = [] + x, mask = self.conv_first(x, masks_in) # input size + skips.append(x) + for i, block in enumerate(self.enc_conv): # input size to 64 + x, mask = block(x, mask) + if i != len(self.enc_conv) - 1: + skips.append(x) + + x_size = x.size()[-2:] + x = feature2token(x) + mask = feature2token(mask) + mid = len(self.tran) // 2 + for i, block in enumerate(self.tran): # 64 to 16 + if i < mid: + x, x_size, mask = block(x, x_size, mask) + skips.append(x) + elif i > mid: + x, x_size, mask = block(x, x_size, None) + x = x + skips[mid - i] + else: + x, x_size, mask = block(x, x_size, None) + + mul_map = torch.ones_like(x) * 0.5 + mul_map = F.dropout(mul_map, training=True) + ws = self.ws_style(ws[:, -1]) + add_n = self.to_square(ws).unsqueeze(1) + add_n = ( + F.interpolate( + add_n, size=x.size(1), mode="linear", align_corners=False + ) + .squeeze(1) + .unsqueeze(-1) + ) + x = x * mul_map + add_n * (1 - mul_map) + gs = self.to_style( + self.down_conv(token2feature(x, x_size)).flatten(start_dim=1) + ) + style = torch.cat([gs, ws], dim=1) + + x = token2feature(x, x_size).contiguous() + img = None + for i, block in enumerate(self.dec_conv): + x, img = block( + x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode + ) + + # ensemble + img = img * (1 - masks_in) + images_in * masks_in + + return img + + +class SynthesisNet(nn.Module): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels=3, # Number of color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_decay=1.0, + channel_max=512, # Maximum number of channels in any layer. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + drop_rate=0.5, + use_noise=False, + demodulate=True, + ): + super().__init__() + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 + + self.num_layers = resolution_log2 * 2 - 3 * 2 + self.img_resolution = img_resolution + self.resolution_log2 = resolution_log2 + + # first stage + self.first_stage = FirstStage( + img_channels, + img_resolution=img_resolution, + w_dim=w_dim, + use_noise=False, + demodulate=demodulate, + ) + + # second stage + self.enc = Encoder( + resolution_log2, img_channels, activation, patch_size=5, channels=16 + ) + self.to_square = FullyConnectedLayer( + in_features=w_dim, out_features=16 * 16, activation=activation + ) + self.to_style = ToStyle( + in_channels=nf(4), + out_channels=nf(2) * 2, + activation=activation, + drop_rate=drop_rate, + ) + style_dim = w_dim + nf(2) * 2 + self.dec = Decoder( + resolution_log2, activation, style_dim, use_noise, demodulate, img_channels + ) + + def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False): + out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode) + + # encoder + x = images_in * masks_in + out_stg1 * (1 - masks_in) + x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1) + E_features = self.enc(x) + + fea_16 = E_features[4] + mul_map = torch.ones_like(fea_16) * 0.5 + mul_map = F.dropout(mul_map, training=True) + add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1) + add_n = F.interpolate( + add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False + ) + fea_16 = fea_16 * mul_map + add_n * (1 - mul_map) + E_features[4] = fea_16 + + # style + gs = self.to_style(fea_16) + + # decoder + img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode) + + # ensemble + img = img * (1 - masks_in) + images_in * masks_in + + if not return_stg1: + return img + else: + return img, out_stg1 + + +class Generator(nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # resolution of generated image + img_channels, # Number of input color channels. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + self.synthesis = SynthesisNet( + w_dim=w_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **synthesis_kwargs, + ) + self.mapping = MappingNet( + z_dim=z_dim, + c_dim=c_dim, + w_dim=w_dim, + num_ws=self.synthesis.num_layers, + **mapping_kwargs, + ) + + def forward( + self, + images_in, + masks_in, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + skip_w_avg_update=False, + noise_mode="none", + return_stg1=False, + ): + ws = self.mapping( + z, + c, + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + skip_w_avg_update=skip_w_avg_update, + ) + img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode) + return img + + +class Discriminator(torch.nn.Module): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + activation="lrelu", + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 + self.resolution_log2 = resolution_log2 + + if cmap_dim == None: + cmap_dim = nf(2) + if c_dim == 0: + cmap_dim = 0 + self.cmap_dim = cmap_dim + + if c_dim > 0: + self.mapping = MappingNet( + z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None + ) + + Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] + for res in range(resolution_log2, 2, -1): + Dis.append(DisBlock(nf(res), nf(res - 1), activation)) + + if mbstd_num_channels > 0: + Dis.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis.append( + Conv2dLayer( + nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation + ) + ) + self.Dis = nn.Sequential(*Dis) + + self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation) + self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) + + # for 64x64 + Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)] + for res in range(resolution_log2, 2, -1): + Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation)) + + if mbstd_num_channels > 0: + Dis_stg1.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis_stg1.append( + Conv2dLayer( + nf(2) // 2 + mbstd_num_channels, + nf(2) // 2, + kernel_size=3, + activation=activation, + ) + ) + self.Dis_stg1 = nn.Sequential(*Dis_stg1) + + self.fc0_stg1 = FullyConnectedLayer( + nf(2) // 2 * 4**2, nf(2) // 2, activation=activation + ) + self.fc1_stg1 = FullyConnectedLayer( + nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim + ) + + def forward(self, images_in, masks_in, images_stg1, c): + x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1)) + x = self.fc1(self.fc0(x.flatten(start_dim=1))) + + x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1)) + x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1))) + + if self.c_dim > 0: + cmap = self.mapping(None, c) + + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * ( + 1 / np.sqrt(self.cmap_dim) + ) + + return x, x_stg1 + + +MAT_MODEL_URL = os.environ.get( + "MAT_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth", +) + +MAT_MODEL_MD5 = os.environ.get("MAT_MODEL_MD5", "8ca927835fa3f5e21d65ffcb165377ed") + + +class MAT(InpaintModel): + name = "mat" + min_size = 512 + pad_mod = 512 + pad_to_square = True + is_erase_model = True + + def init_model(self, device, **kwargs): + seed = 240 # pick up a random number + set_seed(seed) + + fp16 = not kwargs.get("no_half", False) + use_gpu = "cuda" in str(device) and torch.cuda.is_available() + self.torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + + G = Generator( + z_dim=512, + c_dim=0, + w_dim=512, + img_resolution=512, + img_channels=3, + mapping_kwargs={"torch_dtype": self.torch_dtype}, + ).to(self.torch_dtype) + # fmt: off + self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5) + self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device) + self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype) + # fmt: on + + @staticmethod + def download(): + download_model(MAT_MODEL_URL, MAT_MODEL_MD5) + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL)) + + def forward(self, image, mask, config: InpaintRequest): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] mask area == 255 + return: BGR IMAGE + """ + + image = norm_img(image) # [0, 1] + image = image * 2 - 1 # [0, 1] -> [-1, 1] + + mask = (mask > 127) * 255 + mask = 255 - mask + mask = norm_img(mask) + + image = ( + torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device) + ) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device) + + output = self.model( + image, mask, self.z, self.label, truncation_psi=1, noise_mode="none" + ) + output = ( + (output.permute(0, 2, 3, 1) * 127.5 + 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + ) + output = output[0].cpu().numpy() + cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/custom-demo/back-end/model/mi_gan.py b/custom-demo/back-end/model/mi_gan.py new file mode 100644 index 0000000..f1ce25f --- /dev/null +++ b/custom-demo/back-end/model/mi_gan.py @@ -0,0 +1,110 @@ +import os + +import cv2 +import torch + +from iopaint.helper import ( + load_jit_model, + download_model, + get_cache_path_by_url, + boxes_from_mask, + resize_max_size, + norm_img, +) +from .base import InpaintModel +from iopaint.schema import InpaintRequest + +MIGAN_MODEL_URL = os.environ.get( + "MIGAN_MODEL_URL", + "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt", +) +MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c") + + +class MIGAN(InpaintModel): + name = "migan" + min_size = 512 + pad_mod = 512 + pad_to_square = True + is_erase_model = True + + def init_model(self, device, **kwargs): + self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval() + + @staticmethod + def download(): + download_model(MIGAN_MODEL_URL, MIGAN_MODEL_MD5) + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL)) + + @torch.no_grad() + def __call__(self, image, mask, config: InpaintRequest): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + if image.shape[0] == 512 and image.shape[1] == 512: + return self._pad_forward(image, mask, config) + + boxes = boxes_from_mask(mask) + crop_result = [] + config.hd_strategy_crop_margin = 128 + for box in boxes: + crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config) + origin_size = crop_image.shape[:2] + resize_image = resize_max_size(crop_image, size_limit=512) + resize_mask = resize_max_size(crop_mask, size_limit=512) + inpaint_result = self._pad_forward(resize_image, resize_mask, config) + + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + + original_pixel_indices = crop_mask < 127 + inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][ + original_pixel_indices + ] + + crop_result.append((inpaint_result, crop_box)) + + inpaint_result = image[:, :, ::-1].copy() + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + inpaint_result[y1:y2, x1:x2, :] = crop_image + + return inpaint_result + + def forward(self, image, mask, config: InpaintRequest): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] mask area == 255 + return: BGR IMAGE + """ + + image = norm_img(image) # [0, 1] + image = image * 2 - 1 # [0, 1] -> [-1, 1] + mask = (mask > 120) * 255 + mask = norm_img(mask) + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + + erased_img = image * (1 - mask) + input_image = torch.cat([0.5 - mask, erased_img], dim=1) + + output = self.model(input_image) + output = ( + (output.permute(0, 2, 3, 1) * 127.5 + 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + ) + output = output[0].cpu().numpy() + cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/custom-demo/back-end/model/opencv2.py b/custom-demo/back-end/model/opencv2.py new file mode 100644 index 0000000..de47209 --- /dev/null +++ b/custom-demo/back-end/model/opencv2.py @@ -0,0 +1,29 @@ +import cv2 +from .base import InpaintModel +from iopaint.schema import InpaintRequest + +flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA} + + +class OpenCV2(InpaintModel): + name = "cv2" + pad_mod = 1 + is_erase_model = True + + @staticmethod + def is_downloaded() -> bool: + return True + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + cur_res = cv2.inpaint( + image[:, :, ::-1], + mask, + inpaintRadius=config.cv2_radius, + flags=flag_map[config.cv2_flag], + ) + return cur_res diff --git a/custom-demo/back-end/model/original_sd_configs/__init__.py b/custom-demo/back-end/model/original_sd_configs/__init__.py new file mode 100644 index 0000000..23896a7 --- /dev/null +++ b/custom-demo/back-end/model/original_sd_configs/__init__.py @@ -0,0 +1,19 @@ +from pathlib import Path +from typing import Dict + +CURRENT_DIR = Path(__file__).parent.absolute() + + +def get_config_files() -> Dict[str, Path]: + """ + - `v1`: Config file for Stable Diffusion v1 + - `v2`: Config file for Stable Diffusion v2 + - `xl`: Config file for Stable Diffusion XL + - `xl_refiner`: Config file for Stable Diffusion XL Refiner + """ + return { + "v1": CURRENT_DIR / "v1-inference.yaml", + "v2": CURRENT_DIR / "v2-inference-v.yaml", + "xl": CURRENT_DIR / "sd_xl_base.yaml", + "xl_refiner": CURRENT_DIR / "sd_xl_refiner.yaml", + } diff --git a/custom-demo/back-end/model/original_sd_configs/sd_xl_base.yaml b/custom-demo/back-end/model/original_sd_configs/sd_xl_base.yaml new file mode 100644 index 0000000..6047379 --- /dev/null +++ b/custom-demo/back-end/model/original_sd_configs/sd_xl_base.yaml @@ -0,0 +1,93 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/custom-demo/back-end/model/original_sd_configs/sd_xl_refiner.yaml b/custom-demo/back-end/model/original_sd_configs/sd_xl_refiner.yaml new file mode 100644 index 0000000..2d5ab44 --- /dev/null +++ b/custom-demo/back-end/model/original_sd_configs/sd_xl_refiner.yaml @@ -0,0 +1,86 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2560 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 384 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 4 + context_dim: [1280, 1280, 1280, 1280] + spatial_transformer_attn_type: softmax-xformers + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + legacy: False + freeze: True + layer: penultimate + always_return_pooled: True + + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: aesthetic_score + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/custom-demo/back-end/model/original_sd_configs/v1-inference.yaml b/custom-demo/back-end/model/original_sd_configs/v1-inference.yaml new file mode 100644 index 0000000..d4effe5 --- /dev/null +++ b/custom-demo/back-end/model/original_sd_configs/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/custom-demo/back-end/model/original_sd_configs/v2-inference-v.yaml b/custom-demo/back-end/model/original_sd_configs/v2-inference-v.yaml new file mode 100644 index 0000000..8ec8dfb --- /dev/null +++ b/custom-demo/back-end/model/original_sd_configs/v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/custom-demo/back-end/model/paint_by_example.py b/custom-demo/back-end/model/paint_by_example.py new file mode 100644 index 0000000..bf1e5b7 --- /dev/null +++ b/custom-demo/back-end/model/paint_by_example.py @@ -0,0 +1,68 @@ +import PIL +import PIL.Image +import cv2 +import torch +from loguru import logger + +from iopaint.helper import decode_base64_to_image +from .base import DiffusionInpaintModel +from iopaint.schema import InpaintRequest +from .utils import get_torch_dtype, enable_low_mem, is_local_files_only + + +class PaintByExample(DiffusionInpaintModel): + name = "Fantasy-Studio/Paint-by-Example" + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from diffusers import DiffusionPipeline + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + model_kwargs = { + "local_files_only": is_local_files_only(**kwargs), + } + + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Paint By Example Model NSFW checker") + model_kwargs.update( + dict(safety_checker=None, requires_safety_checker=False) + ) + + self.model = DiffusionPipeline.from_pretrained( + self.name, torch_dtype=torch_dtype, **model_kwargs + ) + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + # TODO: gpu_id + if kwargs.get("cpu_offload", False) and use_gpu: + self.model.image_encoder = self.model.image_encoder.to(device) + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + if config.paint_by_example_example_image is None: + raise ValueError("paint_by_example_example_image is required") + example_image, _, _ = decode_base64_to_image( + config.paint_by_example_example_image + ) + output = self.model( + image=PIL.Image.fromarray(image), + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + example_image=PIL.Image.fromarray(example_image), + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature", + output_type="np.array", + generator=torch.manual_seed(config.sd_seed), + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output diff --git a/custom-demo/back-end/model/plms_sampler.py b/custom-demo/back-end/model/plms_sampler.py new file mode 100644 index 0000000..131a8f4 --- /dev/null +++ b/custom-demo/back-end/model/plms_sampler.py @@ -0,0 +1,225 @@ +# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py +import torch +import numpy as np +from .utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like +from tqdm import tqdm + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + steps, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + return img + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/custom-demo/back-end/model/power_paint/__init__.py b/custom-demo/back-end/model/power_paint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/model/power_paint/pipeline_powerpaint.py b/custom-demo/back-end/model/power_paint/pipeline_powerpaint.py new file mode 100644 index 0000000..13c1d27 --- /dev/null +++ b/custom-demo/back-end/model/power_paint/pipeline_powerpaint.py @@ -0,0 +1,1243 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + UNet2DConditionModel, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_mask_and_masked_image( + image, mask, height, width, return_image: bool = False +): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError( + f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not" + ) + + # Batch single image + if image.ndim == 3: + assert ( + image.shape[0] == 3 + ), "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert ( + image.ndim == 4 and mask.ndim == 4 + ), "Image and Mask must have 4 dimensions" + assert ( + image.shape[-2:] == mask.shape[-2:] + ), "Image and Mask must have the same spatial dimensions" + assert ( + image.shape[0] == mask.shape[0] + ), "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError( + f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" + ) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [ + i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image + ] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate( + [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 + ) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: Union[AutoencoderKL, AsymmetricAutoencoderKL], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if ( + hasattr(scheduler.config, "skip_prk_steps") + and scheduler.config.skip_prk_steps is False + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate( + "skip_prk_steps not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 9: + logger.info( + f"You have loaded a UNet with {unet.config.in_channels} input channels which." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a + time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. + Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the + iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError( + "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook( + cpu_offloaded_model, device, prev_module_hook=hook + ) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook( + self.safety_checker, device, prev_module_hook=hook + ) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + promptA, + promptB, + t, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA=None, + negative_promptB=None, + t_nag=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + prompt = promptA + negative_prompt = negative_promptA + + if promptA is not None and isinstance(promptA, str): + batch_size = 1 + elif promptA is not None and isinstance(promptA, list): + batch_size = len(promptA) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + promptA = self.maybe_convert_prompt(promptA, self.tokenizer) + + text_inputsA = self.tokenizer( + promptA, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputsB = self.tokenizer( + promptB, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_idsA = text_inputsA.input_ids + text_input_idsB = text_inputsB.input_ids + untruncated_ids = self.tokenizer( + promptA, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_idsA.shape[ + -1 + ] and not torch.equal(text_input_idsA, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputsA.attention_mask.to(device) + else: + attention_mask = None + + # print("text_input_idsA: ",text_input_idsA) + # print("text_input_idsB: ",text_input_idsB) + # print('t: ',t) + + prompt_embedsA = self.text_encoder( + text_input_idsA.to(device), + attention_mask=attention_mask, + ) + prompt_embedsA = prompt_embedsA[0] + + prompt_embedsB = self.text_encoder( + text_input_idsB.to(device), + attention_mask=attention_mask, + ) + prompt_embedsB = prompt_embedsB[0] + prompt_embeds = prompt_embedsA * (t) + (1 - t) * prompt_embedsB + # print("prompt_embeds: ",prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokensA: List[str] + uncond_tokensB: List[str] + if negative_prompt is None: + uncond_tokensA = [""] * batch_size + uncond_tokensB = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokensA = [negative_promptA] + uncond_tokensB = [negative_promptB] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokensA = negative_promptA + uncond_tokensB = negative_promptB + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokensA = self.maybe_convert_prompt( + uncond_tokensA, self.tokenizer + ) + uncond_tokensB = self.maybe_convert_prompt( + uncond_tokensB, self.tokenizer + ) + + max_length = prompt_embeds.shape[1] + uncond_inputA = self.tokenizer( + uncond_tokensA, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_inputB = self.tokenizer( + uncond_tokensB, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_inputA.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embedsA = self.text_encoder( + uncond_inputA.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embedsB = self.text_encoder( + uncond_inputB.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = ( + negative_prompt_embedsA[0] * (t_nag) + + (1 - t_nag) * negative_prompt_embedsB[0] + ) + + # negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # print("prompt_embeds: ",prompt_embeds) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if strength < 0 or strength > 1: + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = ( + noise + if is_strength_max + else self.scheduler.add_noise(image_latents, noise, timestep) + ) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = ( + latents * self.scheduler.init_noise_sigma + if is_strength_max + else latents + ) + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample( + generator=generator[i] + ) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample( + generator=generator + ) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) + if do_classifier_free_guidance + else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__( + self, + promptA: Union[str, List[str]] = None, + promptB: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + tradoff: float = 1.0, + tradoff_nag: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_promptA: Optional[Union[str, List[str]]] = None, + negative_promptB: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + task_class: Union[torch.Tensor, float, int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked + out with `mask_image` and repainted according to `prompt`). + mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted + while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + prompt = promptA + negative_prompt = negative_promptA + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None + ) + prompt_embeds = self._encode_prompt( + promptA, + promptB, + tradoff, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA, + negative_promptB, + tradoff_nag, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + mask_condition = mask.clone() + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.unet.config.in_channels + ): + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + if num_channels_unet == 9: + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents], dim=1 + ) + + # predict the noise residual + if task_class is not None: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + task_class=task_class, + )[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = ( + 1 - init_mask + ) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(self, i, t, {}) + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to( + device=device, dtype=masked_image_latents.dtype + ) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to( + device=device, dtype=masked_image_latents.dtype + ) + condition_kwargs = { + "image": init_image_condition, + "mask": mask_condition, + } + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + **condition_kwargs, + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/custom-demo/back-end/model/power_paint/pipeline_powerpaint_controlnet.py b/custom-demo/back-end/model/power_paint/pipeline_powerpaint_controlnet.py new file mode 100644 index 0000000..cba0f8f --- /dev/null +++ b/custom-demo/back-end/model/power_paint/pipeline_powerpaint_controlnet.py @@ -0,0 +1,1775 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor,is_compiled_module +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.controlnet import MultiControlNetModel + + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_inpaint_condition(image, image_mask): + ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + ... image[image_mask > 0.5] = -1.0 # set as masked pixel + ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + ... image = torch.from_numpy(image) + ... return image + + + >>> control_image = make_inpaint_condition(init_image, mask_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionControlNetInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + + + This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as + [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) + as well as default text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). + Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on + those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + promptA, + promptB, + t, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA=None, + negative_promptB=None, + t_nag = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + prompt = promptA + negative_prompt = negative_promptA + + if promptA is not None and isinstance(promptA, str): + batch_size = 1 + elif promptA is not None and isinstance(promptA, list): + batch_size = len(promptA) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + promptA = self.maybe_convert_prompt(promptA, self.tokenizer) + + text_inputsA = self.tokenizer( + promptA, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputsB = self.tokenizer( + promptB, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_idsA = text_inputsA.input_ids + text_input_idsB = text_inputsB.input_ids + untruncated_ids = self.tokenizer(promptA, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_idsA.shape[-1] and not torch.equal( + text_input_idsA, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputsA.attention_mask.to(device) + else: + attention_mask = None + + # print("text_input_idsA: ",text_input_idsA) + # print("text_input_idsB: ",text_input_idsB) + # print('t: ',t) + + prompt_embedsA = self.text_encoder( + text_input_idsA.to(device), + attention_mask=attention_mask, + ) + prompt_embedsA = prompt_embedsA[0] + + prompt_embedsB = self.text_encoder( + text_input_idsB.to(device), + attention_mask=attention_mask, + ) + prompt_embedsB = prompt_embedsB[0] + prompt_embeds = prompt_embedsA*(t)+(1-t)*prompt_embedsB + # print("prompt_embeds: ",prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokensA: List[str] + uncond_tokensB: List[str] + if negative_prompt is None: + uncond_tokensA = [""] * batch_size + uncond_tokensB = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokensA = [negative_promptA] + uncond_tokensB = [negative_promptB] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokensA = negative_promptA + uncond_tokensB = negative_promptB + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokensA = self.maybe_convert_prompt(uncond_tokensA, self.tokenizer) + uncond_tokensB = self.maybe_convert_prompt(uncond_tokensB, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_inputA = self.tokenizer( + uncond_tokensA, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_inputB = self.tokenizer( + uncond_tokensB, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_inputA.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embedsA = self.text_encoder( + uncond_inputA.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embedsB = self.text_encoder( + uncond_inputB.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embedsA[0]*(t_nag)+(1-t_nag)*negative_prompt_embedsB[0] + + # negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # print("prompt_embeds: ",prompt_embeds) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + @torch.no_grad() + def predict_woControl( + self, + promptA: Union[str, List[str]] = None, + promptB: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + tradoff: float = 1.0, + tradoff_nag: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_promptA: Optional[Union[str, List[str]]] = None, + negative_promptB: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + task_class: Union[torch.Tensor, float, int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked + out with `mask_image` and repainted according to `prompt`). + mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted + while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + prompt = promptA + negative_prompt = negative_promptA + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + promptA, + promptB, + tradoff, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA, + negative_promptB, + tradoff_nag, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + mask_condition = mask.clone() + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + if task_class is not None: + noise_pred = self.unet( + sample = latent_model_input, + timestep = t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + task_class = task_class, + )[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + promptA: Union[str, List[str]] = None, + promptB: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + tradoff: float = 1.0, + tradoff_nag: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_promptA: Optional[Union[str, List[str]]] = None, + negative_promptB: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + prompt = promptA + negative_prompt = negative_promptA + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + promptA, + promptB, + tradoff, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_promptA, + negative_promptB, + tradoff_nag, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/custom-demo/back-end/model/power_paint/power_paint.py b/custom-demo/back-end/model/power_paint/power_paint.py new file mode 100644 index 0000000..f17a5a3 --- /dev/null +++ b/custom-demo/back-end/model/power_paint/power_paint.py @@ -0,0 +1,101 @@ +from PIL import Image +import PIL.Image +import cv2 +import torch +from loguru import logger + +from ..base import DiffusionInpaintModel +from ..helper.cpu_text_encoder import CPUTextEncoderWrapper +from ..utils import ( + handle_from_pretrained_exceptions, + get_torch_dtype, + enable_low_mem, + is_local_files_only, +) +from iopaint.schema import InpaintRequest +from .powerpaint_tokenizer import add_task_to_prompt +from ...const import POWERPAINT_NAME + + +class PowerPaint(DiffusionInpaintModel): + name = POWERPAINT_NAME + pad_mod = 8 + min_size = 512 + lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5" + + def init_model(self, device: torch.device, **kwargs): + from .pipeline_powerpaint import StableDiffusionInpaintPipeline + from .powerpaint_tokenizer import PowerPaintTokenizer + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + model_kwargs = {"local_files_only": is_local_files_only(**kwargs)} + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + self.model = handle_from_pretrained_exceptions( + StableDiffusionInpaintPipeline.from_pretrained, + pretrained_model_name_or_path=self.name, + variant="fp16", + torch_dtype=torch_dtype, + **model_kwargs, + ) + self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer) + + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + self.set_scheduler(config) + + img_h, img_w = image.shape[:2] + promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt( + config.prompt, config.negative_prompt, config.powerpaint_task + ) + + output = self.model( + image=PIL.Image.fromarray(image), + promptA=promptA, + promptB=promptB, + tradoff=config.fitting_degree, + tradoff_nag=config.fitting_degree, + negative_promptA=negative_promptA, + negative_promptB=negative_promptB, + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + num_inference_steps=config.sd_steps, + strength=config.sd_strength, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + callback_steps=1, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output diff --git a/custom-demo/back-end/model/power_paint/powerpaint_tokenizer.py b/custom-demo/back-end/model/power_paint/powerpaint_tokenizer.py new file mode 100644 index 0000000..39d5cb7 --- /dev/null +++ b/custom-demo/back-end/model/power_paint/powerpaint_tokenizer.py @@ -0,0 +1,540 @@ +import torch +import torch.nn as nn +import copy +import random +from typing import Any, List, Optional, Union +from transformers import CLIPTokenizer + +from iopaint.schema import PowerPaintTask + + +def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask): + if task == PowerPaintTask.object_remove: + promptA = prompt + " P_ctxt" + promptB = prompt + " P_ctxt" + negative_promptA = negative_prompt + " P_obj" + negative_promptB = negative_prompt + " P_obj" + elif task == PowerPaintTask.shape_guided: + promptA = prompt + " P_shape" + promptB = prompt + " P_ctxt" + negative_promptA = negative_prompt + negative_promptB = negative_prompt + elif task == PowerPaintTask.outpainting: + promptA = prompt + " P_ctxt" + promptB = prompt + " P_ctxt" + negative_promptA = negative_prompt + " P_obj" + negative_promptB = negative_prompt + " P_obj" + else: + promptA = prompt + " P_obj" + promptB = prompt + " P_obj" + negative_promptA = negative_prompt + negative_promptB = negative_prompt + + return promptA, promptB, negative_promptA, negative_promptB + + +class PowerPaintTokenizer: + def __init__(self, tokenizer: CLIPTokenizer): + self.wrapped = tokenizer + self.token_map = {} + placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"] + num_vec_per_token = 10 + for placeholder_token in placeholder_tokens: + output = [] + for i in range(num_vec_per_token): + ith_token = placeholder_token + f"_{i}" + output.append(ith_token) + self.token_map[placeholder_token] = output + + def __getattr__(self, name: str) -> Any: + if name == "wrapped": + return super().__getattr__("wrapped") + + try: + return getattr(self.wrapped, name) + except AttributeError: + try: + return super().__getattr__(name) + except AttributeError: + raise AttributeError( + "'name' cannot be found in both " + f"'{self.__class__.__name__}' and " + f"'{self.__class__.__name__}.tokenizer'." + ) + + def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs): + """Attempt to add tokens to the tokenizer. + + Args: + tokens (Union[str, List[str]]): The tokens to be added. + """ + num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs) + assert num_added_tokens != 0, ( + f"The tokenizer already contains the token {tokens}. Please pass " + "a different `placeholder_token` that is not already in the " + "tokenizer." + ) + + def get_token_info(self, token: str) -> dict: + """Get the information of a token, including its start and end index in + the current tokenizer. + + Args: + token (str): The token to be queried. + + Returns: + dict: The information of the token, including its start and end + index in current tokenizer. + """ + token_ids = self.__call__(token).input_ids + start, end = token_ids[1], token_ids[-2] + 1 + return {"name": token, "start": start, "end": end} + + def add_placeholder_token( + self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs + ): + """Add placeholder tokens to the tokenizer. + + Args: + placeholder_token (str): The placeholder token to be added. + num_vec_per_token (int, optional): The number of vectors of + the added placeholder token. + *args, **kwargs: The arguments for `self.wrapped.add_tokens`. + """ + output = [] + if num_vec_per_token == 1: + self.try_adding_tokens(placeholder_token, *args, **kwargs) + output.append(placeholder_token) + else: + output = [] + for i in range(num_vec_per_token): + ith_token = placeholder_token + f"_{i}" + self.try_adding_tokens(ith_token, *args, **kwargs) + output.append(ith_token) + + for token in self.token_map: + if token in placeholder_token: + raise ValueError( + f"The tokenizer already has placeholder token {token} " + f"that can get confused with {placeholder_token} " + "keep placeholder tokens independent" + ) + self.token_map[placeholder_token] = output + + def replace_placeholder_tokens_in_text( + self, + text: Union[str, List[str]], + vector_shuffle: bool = False, + prop_tokens_to_load: float = 1.0, + ) -> Union[str, List[str]]: + """Replace the keywords in text with placeholder tokens. This function + will be called in `self.__call__` and `self.encode`. + + Args: + text (Union[str, List[str]]): The text to be processed. + vector_shuffle (bool, optional): Whether to shuffle the vectors. + Defaults to False. + prop_tokens_to_load (float, optional): The proportion of tokens to + be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0. + + Returns: + Union[str, List[str]]: The processed text. + """ + if isinstance(text, list): + output = [] + for i in range(len(text)): + output.append( + self.replace_placeholder_tokens_in_text( + text[i], vector_shuffle=vector_shuffle + ) + ) + return output + + for placeholder_token in self.token_map: + if placeholder_token in text: + tokens = self.token_map[placeholder_token] + tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)] + if vector_shuffle: + tokens = copy.copy(tokens) + random.shuffle(tokens) + text = text.replace(placeholder_token, " ".join(tokens)) + return text + + def replace_text_with_placeholder_tokens( + self, text: Union[str, List[str]] + ) -> Union[str, List[str]]: + """Replace the placeholder tokens in text with the original keywords. + This function will be called in `self.decode`. + + Args: + text (Union[str, List[str]]): The text to be processed. + + Returns: + Union[str, List[str]]: The processed text. + """ + if isinstance(text, list): + output = [] + for i in range(len(text)): + output.append(self.replace_text_with_placeholder_tokens(text[i])) + return output + + for placeholder_token, tokens in self.token_map.items(): + merged_tokens = " ".join(tokens) + if merged_tokens in text: + text = text.replace(merged_tokens, placeholder_token) + return text + + def __call__( + self, + text: Union[str, List[str]], + *args, + vector_shuffle: bool = False, + prop_tokens_to_load: float = 1.0, + **kwargs, + ): + """The call function of the wrapper. + + Args: + text (Union[str, List[str]]): The text to be tokenized. + vector_shuffle (bool, optional): Whether to shuffle the vectors. + Defaults to False. + prop_tokens_to_load (float, optional): The proportion of tokens to + be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0 + *args, **kwargs: The arguments for `self.wrapped.__call__`. + """ + replaced_text = self.replace_placeholder_tokens_in_text( + text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load + ) + + return self.wrapped.__call__(replaced_text, *args, **kwargs) + + def encode(self, text: Union[str, List[str]], *args, **kwargs): + """Encode the passed text to token index. + + Args: + text (Union[str, List[str]]): The text to be encode. + *args, **kwargs: The arguments for `self.wrapped.__call__`. + """ + replaced_text = self.replace_placeholder_tokens_in_text(text) + return self.wrapped(replaced_text, *args, **kwargs) + + def decode( + self, token_ids, return_raw: bool = False, *args, **kwargs + ) -> Union[str, List[str]]: + """Decode the token index to text. + + Args: + token_ids: The token index to be decoded. + return_raw: Whether keep the placeholder token in the text. + Defaults to False. + *args, **kwargs: The arguments for `self.wrapped.decode`. + + Returns: + Union[str, List[str]]: The decoded text. + """ + text = self.wrapped.decode(token_ids, *args, **kwargs) + if return_raw: + return text + replaced_text = self.replace_text_with_placeholder_tokens(text) + return replaced_text + + +class EmbeddingLayerWithFixes(nn.Module): + """The revised embedding layer to support external embeddings. This design + of this class is inspired by https://github.com/AUTOMATIC1111/stable- + diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi + jack.py#L224 # noqa. + + Args: + wrapped (nn.Emebdding): The embedding layer to be wrapped. + external_embeddings (Union[dict, List[dict]], optional): The external + embeddings added to this layer. Defaults to None. + """ + + def __init__( + self, + wrapped: nn.Embedding, + external_embeddings: Optional[Union[dict, List[dict]]] = None, + ): + super().__init__() + self.wrapped = wrapped + self.num_embeddings = wrapped.weight.shape[0] + + self.external_embeddings = [] + if external_embeddings: + self.add_embeddings(external_embeddings) + + self.trainable_embeddings = nn.ParameterDict() + + @property + def weight(self): + """Get the weight of wrapped embedding layer.""" + return self.wrapped.weight + + def check_duplicate_names(self, embeddings: List[dict]): + """Check whether duplicate names exist in list of 'external + embeddings'. + + Args: + embeddings (List[dict]): A list of embedding to be check. + """ + names = [emb["name"] for emb in embeddings] + assert len(names) == len(set(names)), ( + "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'" + ) + + def check_ids_overlap(self, embeddings): + """Check whether overlap exist in token ids of 'external_embeddings'. + + Args: + embeddings (List[dict]): A list of embedding to be check. + """ + ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings] + ids_range.sort() # sort by 'start' + # check if 'end' has overlapping + for idx in range(len(ids_range) - 1): + name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1] + assert ids_range[idx][1] <= ids_range[idx + 1][0], ( + f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'." + ) + + def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]): + """Add external embeddings to this layer. + + Use case: + + >>> 1. Add token to tokenizer and get the token id. + >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32') + >>> # 'how much' in kiswahili + >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4) + >>> + >>> 2. Add external embeddings to the model. + >>> new_embedding = { + >>> 'name': 'ngapi', # 'how much' in kiswahili + >>> 'embedding': torch.ones(1, 15) * 4, + >>> 'start': tokenizer.get_token_info('kwaheri')['start'], + >>> 'end': tokenizer.get_token_info('kwaheri')['end'], + >>> 'trainable': False # if True, will registry as a parameter + >>> } + >>> embedding_layer = nn.Embedding(10, 15) + >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer) + >>> embedding_layer_wrapper.add_embeddings(new_embedding) + >>> + >>> 3. Forward tokenizer and embedding layer! + >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?'] + >>> input_ids = tokenizer( + >>> input_text, padding='max_length', truncation=True, + >>> return_tensors='pt')['input_ids'] + >>> out_feat = embedding_layer_wrapper(input_ids) + >>> + >>> 4. Let's validate the result! + >>> assert (out_feat[0, 3: 7] == 2.3).all() + >>> assert (out_feat[2, 5: 9] == 2.3).all() + + Args: + embeddings (Union[dict, list[dict]]): The external embeddings to + be added. Each dict must contain the following 4 fields: 'name' + (the name of this embedding), 'embedding' (the embedding + tensor), 'start' (the start token id of this embedding), 'end' + (the end token id of this embedding). For example: + `{name: NAME, start: START, end: END, embedding: torch.Tensor}` + """ + if isinstance(embeddings, dict): + embeddings = [embeddings] + + self.external_embeddings += embeddings + self.check_duplicate_names(self.external_embeddings) + self.check_ids_overlap(self.external_embeddings) + + # set for trainable + added_trainable_emb_info = [] + for embedding in embeddings: + trainable = embedding.get("trainable", False) + if trainable: + name = embedding["name"] + embedding["embedding"] = torch.nn.Parameter(embedding["embedding"]) + self.trainable_embeddings[name] = embedding["embedding"] + added_trainable_emb_info.append(name) + + added_emb_info = [emb["name"] for emb in embeddings] + added_emb_info = ", ".join(added_emb_info) + print(f"Successfully add external embeddings: {added_emb_info}.", "current") + + if added_trainable_emb_info: + added_trainable_emb_info = ", ".join(added_trainable_emb_info) + print( + "Successfully add trainable external embeddings: " + f"{added_trainable_emb_info}", + "current", + ) + + def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Replace external input ids to 0. + + Args: + input_ids (torch.Tensor): The input ids to be replaced. + + Returns: + torch.Tensor: The replaced input ids. + """ + input_ids_fwd = input_ids.clone() + input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0 + return input_ids_fwd + + def replace_embeddings( + self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict + ) -> torch.Tensor: + """Replace external embedding to the embedding layer. Noted that, in + this function we use `torch.cat` to avoid inplace modification. + + Args: + input_ids (torch.Tensor): The original token ids. Shape like + [LENGTH, ]. + embedding (torch.Tensor): The embedding of token ids after + `replace_input_ids` function. + external_embedding (dict): The external embedding to be replaced. + + Returns: + torch.Tensor: The replaced embedding. + """ + new_embedding = [] + + name = external_embedding["name"] + start = external_embedding["start"] + end = external_embedding["end"] + target_ids_to_replace = [i for i in range(start, end)] + ext_emb = external_embedding["embedding"] + + # do not need to replace + if not (input_ids == start).any(): + return embedding + + # start replace + s_idx, e_idx = 0, 0 + while e_idx < len(input_ids): + if input_ids[e_idx] == start: + if e_idx != 0: + # add embedding do not need to replace + new_embedding.append(embedding[s_idx:e_idx]) + + # check if the next embedding need to replace is valid + actually_ids_to_replace = [ + int(i) for i in input_ids[e_idx : e_idx + end - start] + ] + assert actually_ids_to_replace == target_ids_to_replace, ( + f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. " + f"Expect '{target_ids_to_replace}' for embedding " + f"'{name}' but found '{actually_ids_to_replace}'." + ) + + new_embedding.append(ext_emb) + + s_idx = e_idx + end - start + e_idx = s_idx + 1 + else: + e_idx += 1 + + if e_idx == len(input_ids): + new_embedding.append(embedding[s_idx:e_idx]) + + return torch.cat(new_embedding, dim=0) + + def forward( + self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None + ): + """The forward function. + + Args: + input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or + [LENGTH, ]. + external_embeddings (Optional[List[dict]]): The external + embeddings. If not passed, only `self.external_embeddings` + will be used. Defaults to None. + + input_ids: shape like [bz, LENGTH] or [LENGTH]. + """ + assert input_ids.ndim in [1, 2] + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + if external_embeddings is None and not self.external_embeddings: + return self.wrapped(input_ids) + + input_ids_fwd = self.replace_input_ids(input_ids) + inputs_embeds = self.wrapped(input_ids_fwd) + + vecs = [] + + if external_embeddings is None: + external_embeddings = [] + elif isinstance(external_embeddings, dict): + external_embeddings = [external_embeddings] + embeddings = self.external_embeddings + external_embeddings + + for input_id, embedding in zip(input_ids, inputs_embeds): + new_embedding = embedding + for external_embedding in embeddings: + new_embedding = self.replace_embeddings( + input_id, new_embedding, external_embedding + ) + vecs.append(new_embedding) + + return torch.stack(vecs) + + +def add_tokens( + tokenizer, + text_encoder, + placeholder_tokens: list, + initialize_tokens: list = None, + num_vectors_per_token: int = 1, +): + """Add token for training. + + # TODO: support add tokens as dict, then we can load pretrained tokens. + """ + if initialize_tokens is not None: + assert len(initialize_tokens) == len( + placeholder_tokens + ), "placeholder_token should be the same length as initialize_token" + for ii in range(len(placeholder_tokens)): + tokenizer.add_placeholder_token( + placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token + ) + + # text_encoder.set_embedding_layer() + embedding_layer = text_encoder.text_model.embeddings.token_embedding + text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes( + embedding_layer + ) + embedding_layer = text_encoder.text_model.embeddings.token_embedding + + assert embedding_layer is not None, ( + "Do not support get embedding layer for current text encoder. " + "Please check your configuration." + ) + initialize_embedding = [] + if initialize_tokens is not None: + for ii in range(len(placeholder_tokens)): + init_id = tokenizer(initialize_tokens[ii]).input_ids[1] + temp_embedding = embedding_layer.weight[init_id] + initialize_embedding.append( + temp_embedding[None, ...].repeat(num_vectors_per_token, 1) + ) + else: + for ii in range(len(placeholder_tokens)): + init_id = tokenizer("a").input_ids[1] + temp_embedding = embedding_layer.weight[init_id] + len_emb = temp_embedding.shape[0] + init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0 + initialize_embedding.append(init_weight) + + # initialize_embedding = torch.cat(initialize_embedding,dim=0) + + token_info_all = [] + for ii in range(len(placeholder_tokens)): + token_info = tokenizer.get_token_info(placeholder_tokens[ii]) + token_info["embedding"] = initialize_embedding[ii] + token_info["trainable"] = True + token_info_all.append(token_info) + embedding_layer.add_embeddings(token_info_all) diff --git a/custom-demo/back-end/model/sd.py b/custom-demo/back-end/model/sd.py new file mode 100644 index 0000000..8f42fff --- /dev/null +++ b/custom-demo/back-end/model/sd.py @@ -0,0 +1,129 @@ +import PIL.Image +import cv2 +import torch +from loguru import logger + +from .base import DiffusionInpaintModel +from .helper.cpu_text_encoder import CPUTextEncoderWrapper +from .original_sd_configs import get_config_files +from .utils import ( + handle_from_pretrained_exceptions, + get_torch_dtype, + enable_low_mem, + is_local_files_only, +) +from iopaint.schema import InpaintRequest, ModelType + + +class SD(DiffusionInpaintModel): + pad_mod = 8 + min_size = 512 + lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5" + + def init_model(self, device: torch.device, **kwargs): + from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + + model_kwargs = { + **kwargs.get("pipe_components", {}), + "local_files_only": is_local_files_only(**kwargs), + } + disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get( + "cpu_offload", False + ) + if disable_nsfw_checker: + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + if self.model_info.is_single_file_diffusers: + if self.model_info.model_type == ModelType.DIFFUSERS_SD: + model_kwargs["num_in_channels"] = 4 + else: + model_kwargs["num_in_channels"] = 9 + + self.model = StableDiffusionInpaintPipeline.from_single_file( + self.model_id_or_path, + torch_dtype=torch_dtype, + load_safety_checker=not disable_nsfw_checker, + config_files=get_config_files(), + **model_kwargs, + ) + else: + self.model = handle_from_pretrained_exceptions( + StableDiffusionInpaintPipeline.from_pretrained, + pretrained_model_name_or_path=self.model_id_or_path, + variant="fp16", + torch_dtype=torch_dtype, + **model_kwargs, + ) + + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + self.set_scheduler(config) + + img_h, img_w = image.shape[:2] + + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + num_inference_steps=config.sd_steps, + strength=config.sd_strength, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback_on_step_end=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +class SD15(SD): + name = "runwayml/stable-diffusion-inpainting" + model_id_or_path = "runwayml/stable-diffusion-inpainting" + + +class Anything4(SD): + name = "Sanster/anything-4.0-inpainting" + model_id_or_path = "Sanster/anything-4.0-inpainting" + + +class RealisticVision14(SD): + name = "Sanster/Realistic_Vision_V1.4-inpainting" + model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting" + + +class SD2(SD): + name = "stabilityai/stable-diffusion-2-inpainting" + model_id_or_path = "stabilityai/stable-diffusion-2-inpainting" diff --git a/custom-demo/back-end/model/sdxl.py b/custom-demo/back-end/model/sdxl.py new file mode 100644 index 0000000..29312b1 --- /dev/null +++ b/custom-demo/back-end/model/sdxl.py @@ -0,0 +1,110 @@ +import os + +import PIL.Image +import cv2 +import torch +from diffusers import AutoencoderKL +from loguru import logger + +from iopaint.schema import InpaintRequest, ModelType + +from .base import DiffusionInpaintModel +from .helper.cpu_text_encoder import CPUTextEncoderWrapper +from .original_sd_configs import get_config_files +from .utils import ( + handle_from_pretrained_exceptions, + get_torch_dtype, + enable_low_mem, + is_local_files_only, +) + + +class SDXL(DiffusionInpaintModel): + name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" + pad_mod = 8 + min_size = 512 + lcm_lora_id = "latent-consistency/lcm-lora-sdxl" + model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" + + def init_model(self, device: torch.device, **kwargs): + from diffusers.pipelines import StableDiffusionXLInpaintPipeline + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + + if self.model_info.model_type == ModelType.DIFFUSERS_SDXL: + num_in_channels = 4 + else: + num_in_channels = 9 + + if os.path.isfile(self.model_id_or_path): + self.model = StableDiffusionXLInpaintPipeline.from_single_file( + self.model_id_or_path, + torch_dtype=torch_dtype, + num_in_channels=num_in_channels, + load_safety_checker=False, + config_files=get_config_files() + ) + else: + model_kwargs = { + **kwargs.get("pipe_components", {}), + "local_files_only": is_local_files_only(**kwargs), + } + if "vae" not in model_kwargs: + vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype + ) + model_kwargs["vae"] = vae + self.model = handle_from_pretrained_exceptions( + StableDiffusionXLInpaintPipeline.from_pretrained, + pretrained_model_name_or_path=self.model_id_or_path, + torch_dtype=torch_dtype, + variant="fp16", + **model_kwargs + ) + + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + self.model.text_encoder_2 = CPUTextEncoderWrapper( + self.model.text_encoder_2, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + self.set_scheduler(config) + + img_h, img_w = image.shape[:2] + + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + num_inference_steps=config.sd_steps, + strength=0.999 if config.sd_strength == 1.0 else config.sd_strength, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback_on_step_end=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output diff --git a/custom-demo/back-end/model/utils.py b/custom-demo/back-end/model/utils.py new file mode 100644 index 0000000..73465e8 --- /dev/null +++ b/custom-demo/back-end/model/utils.py @@ -0,0 +1,1033 @@ +import gc +import math +import random +import traceback +from typing import Any + +import torch +import numpy as np +import collections +from itertools import repeat + +from diffusers import ( + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + UniPCMultistepScheduler, + LCMScheduler, + DPMSolverSinglestepScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, +) +from loguru import logger + +from iopaint.schema import SDSampler +from torch import conv2d, conv_transpose2d + + +def make_beta_schedule( + device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ).to(device) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2).to(device) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=device) + + args = timesteps[:, None].float() * freqs[None] + + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +###### MAT and FcF ####### + + +def normalize_2nd_moment(x, dim=1): + return ( + x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt() + ) + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops.""" + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + + +def bias_act( + x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref" +): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ["ref", "cuda"] + return _bias_act_ref( + x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp + ) + + +def _get_filter_size(f): + if f is None: + return 1, 1 + + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + + fw = int(fw) + fh = int(fh) + assert fw >= 1 and fh >= 1 + return fw, fh + + +def _get_weight_shape(w): + shape = [int(sz) for sz in w.shape] + return shape + + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def setup_filter( + f, + device=torch.device("cpu"), + normalize=True, + flip_filter=False, + gain=1, + separable=None, +): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = f.ndim == 1 and f.numel() >= 8 + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + +activation_funcs = { + "linear": EasyDict( + func=lambda x, **_: x, + def_alpha=0, + def_gain=1, + cuda_idx=1, + ref="", + has_2nd_grad=False, + ), + "relu": EasyDict( + func=lambda x, **_: torch.nn.functional.relu(x), + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=2, + ref="y", + has_2nd_grad=False, + ), + "lrelu": EasyDict( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, + def_gain=np.sqrt(2), + cuda_idx=3, + ref="y", + has_2nd_grad=False, + ), + "tanh": EasyDict( + func=lambda x, **_: torch.tanh(x), + def_alpha=0, + def_gain=1, + cuda_idx=4, + ref="y", + has_2nd_grad=True, + ), + "sigmoid": EasyDict( + func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, + def_gain=1, + cuda_idx=5, + ref="y", + has_2nd_grad=True, + ), + "elu": EasyDict( + func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, + def_gain=1, + cuda_idx=6, + ref="y", + has_2nd_grad=True, + ), + "selu": EasyDict( + func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, + def_gain=1, + cuda_idx=7, + ref="y", + has_2nd_grad=True, + ), + "softplus": EasyDict( + func=lambda x, **_: torch.nn.functional.softplus(x), + def_alpha=0, + def_gain=1, + cuda_idx=8, + ref="y", + has_2nd_grad=True, + ), + "swish": EasyDict( + func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=9, + ref="x", + has_2nd_grad=True, + ), +} + + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # assert isinstance(x, torch.Tensor) + # assert impl in ['ref', 'cuda'] + return _upfirdn2d_ref( + x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ) + + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + # upx, upy = _parse_scaling(up) + # downx, downy = _parse_scaling(down) + + upx, upy = up, up + downx, downy = down, down + + # padx0, padx1, pady0, pady1 = _parse_padding(padding) + padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad( + x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)] + ) + x = x[ + :, + :, + max(-pady0, 0) : x.shape[2] - max(-pady1, 0), + max(-padx0, 0) : x.shape[3] - max(-padx1, 0), + ] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + # padx0, padx1, pady0, pady1 = _parse_padding(padding) + padx0, padx1, pady0, pady1 = padding, padding, padding, padding + + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d( + x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl + ) + + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + # upx, upy = up, up + padx0, padx1, pady0, pady1 = _parse_padding(padding) + # padx0, padx1, pady0, pady1 = padding, padding, padding, padding + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d( + x, + f, + up=up, + padding=p, + flip_filter=flip_filter, + gain=gain * upx * upy, + impl=impl, + ) + + +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + G = ( + torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) + if self.group_size is not None + else N + ) + F = self.num_channels + c = C // F + + y = x.reshape( + G, -1, F, c, H, W + ) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + + +class FullyConnectedLayer(torch.nn.Module): + def __init__( + self, + in_features, # Number of input features. + out_features, # Number of output features. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=1, # Learning rate multiplier. + bias_init=0, # Initial value for the additive bias. + ): + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn([out_features, in_features]) / lr_multiplier + ) + self.bias = ( + torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) + if bias + else None + ) + self.activation = activation + + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight * self.weight_gain + b = self.bias + if b is not None and self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == "linear" and b is not None: + # out = torch.addmm(b.unsqueeze(0), x, w.t()) + x = x.matmul(w.t()) + out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)]) + else: + x = x.matmul(w.t()) + out = bias_act(x, b, act=self.activation, dim=x.ndim - 1) + return out + + +def _conv2d_wrapper( + x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True +): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.""" + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if ( + not flip_weight + ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if ( + kw == 1 + and kh == 1 + and stride == 1 + and padding in [0, [0, 0], (0, 0)] + and not transpose + ): + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape( + [in_shape[0], in_channels_per_group, -1] + ) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv_transpose2d if transpose else conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + + +def conv2d_resample( + x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False +): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + # px0, px1, py0, py1 = _parse_padding(padding) + px0, px1, py0, py1 = padding, padding, padding, padding + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d( + x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter + ) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d( + x=x, + f=f, + up=up, + padding=[px0, px1, py0, py1], + gain=up**2, + flip_filter=flip_filter, + ) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) + x = _conv2d_wrapper( + x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight + ) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape( + groups * in_channels_per_group, out_channels // groups, kh, kw + ) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper( + x=x, + w=w, + stride=up, + padding=[pyt, pxt], + groups=groups, + transpose=True, + flip_weight=(not flip_weight), + ) + x = upfirdn2d( + x=x, + f=f, + padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], + gain=up**2, + flip_filter=flip_filter, + ) + if down > 1: + x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper( + x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight + ) + + # Fallback: Generic reference implementation. + x = upfirdn2d( + x=x, + f=(f if up > 1 else None), + up=up, + padding=[px0, px1, py0, py1], + gain=up**2, + flip_filter=flip_filter, + ) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + +class Conv2dLayer(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + channels_last=False, # Expect the input to have memory_format=channels_last? + trainable=True, # Update the weights of this layer during training? + ): + super().__init__() + self.activation = activation + self.up = up + self.down = down + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) + self.act_gain = activation_funcs[activation].def_gain + + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer("weight", weight) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + x = conv2d_resample( + x=x, + w=w, + f=self.resample_filter, + up=self.up, + down=self.down, + padding=self.padding, + ) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + out = bias_act( + x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp + ) + return out + + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + gc.collect() + + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_scheduler(sd_sampler, scheduler_config): + # https://github.com/huggingface/diffusers/issues/4167 + keys_to_pop = ["use_karras_sigmas", "algorithm_type"] + scheduler_config = dict(scheduler_config) + for it in keys_to_pop: + scheduler_config.pop(it, None) + + # fmt: off + samplers = { + SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler], + SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)], + SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")], + SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)], + SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler], + SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)], + SDSampler.dpm2: [KDPM2DiscreteScheduler], + SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)], + SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler], + SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)], + SDSampler.euler: [EulerDiscreteScheduler], + SDSampler.euler_a: [EulerAncestralDiscreteScheduler], + SDSampler.heun: [HeunDiscreteScheduler], + SDSampler.lms: [LMSDiscreteScheduler], + SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)], + SDSampler.ddim: [DDIMScheduler], + SDSampler.pndm: [PNDMScheduler], + SDSampler.uni_pc: [UniPCMultistepScheduler], + SDSampler.lcm: [LCMScheduler], + } + # fmt: on + if sd_sampler in samplers: + if len(samplers[sd_sampler]) == 2: + scheduler_cls, kwargs = samplers[sd_sampler] + else: + scheduler_cls, kwargs = samplers[sd_sampler][0], {} + return scheduler_cls.from_config(scheduler_config, **kwargs) + else: + raise ValueError(sd_sampler) + + +def is_local_files_only(**kwargs) -> bool: + from huggingface_hub.constants import HF_HUB_OFFLINE + + return HF_HUB_OFFLINE or kwargs.get("local_files_only", False) + + +def handle_from_pretrained_exceptions(func, **kwargs): + try: + return func(**kwargs) + except ValueError as e: + if "You are trying to load the model files of the `variant=fp16`" in str(e): + logger.info("variant=fp16 not found, try revision=fp16") + try: + return func(**{**kwargs, "variant": None, "revision": "fp16"}) + except Exception as e: + logger.info("revision=fp16 not found, try revision=main") + return func(**{**kwargs, "variant": None, "revision": "main"}) + raise e + except OSError as e: + previous_traceback = traceback.format_exc() + if "RevisionNotFoundError: 404 Client Error." in previous_traceback: + logger.info("revision=fp16 not found, try revision=main") + return func(**{**kwargs, "variant": None, "revision": "main"}) + elif "Max retries exceeded" in previous_traceback: + logger.exception( + "Fetching model from HuggingFace failed. " + "If this is your first time downloading the model, you may need to set up proxy in terminal." + "If the model has already been downloaded, you can add --local-files-only when starting." + ) + exit(-1) + raise e + except Exception as e: + raise e + + +def get_torch_dtype(device, no_half: bool): + device = str(device) + use_fp16 = not no_half + use_gpu = device == "cuda" + # https://github.com/huggingface/diffusers/issues/4480 + # pipe.enable_attention_slicing and float16 will cause black output on mps + # if device in ["cuda", "mps"] and use_fp16: + if device in ["cuda"] and use_fp16: + return use_gpu, torch.float16 + return use_gpu, torch.float32 + + +def enable_low_mem(pipe, enable: bool): + if torch.backends.mps.is_available(): + # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing + # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. + if enable: + pipe.enable_attention_slicing("max") + else: + # https://huggingface.co/docs/diffusers/optimization/mps + # Devices with less than 64GB of memory are recommended to use enable_attention_slicing + pipe.enable_attention_slicing() + + if enable: + pipe.vae.enable_tiling() diff --git a/custom-demo/back-end/model/zits.py b/custom-demo/back-end/model/zits.py new file mode 100644 index 0000000..d58ac01 --- /dev/null +++ b/custom-demo/back-end/model/zits.py @@ -0,0 +1,476 @@ +import os +import time + +import cv2 +import torch +import torch.nn.functional as F + +from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model +from iopaint.schema import InpaintRequest +import numpy as np + +from .base import InpaintModel + +ZITS_INPAINT_MODEL_URL = os.environ.get( + "ZITS_INPAINT_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt", +) +ZITS_INPAINT_MODEL_MD5 = os.environ.get( + "ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154" +) + +ZITS_EDGE_LINE_MODEL_URL = os.environ.get( + "ZITS_EDGE_LINE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt", +) +ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get( + "ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f" +) + +ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get( + "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt", +) +ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get( + "ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927" +) + +ZITS_WIRE_FRAME_MODEL_URL = os.environ.get( + "ZITS_WIRE_FRAME_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt", +) +ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get( + "ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b" +) + + +def resize(img, height, width, center_crop=False): + imgh, imgw = img.shape[0:2] + + if center_crop and imgh != imgw: + # center crop + side = np.minimum(imgh, imgw) + j = (imgh - side) // 2 + i = (imgw - side) // 2 + img = img[j : j + side, i : i + side, ...] + + if imgh > height and imgw > width: + inter = cv2.INTER_AREA + else: + inter = cv2.INTER_LINEAR + img = cv2.resize(img, (height, width), interpolation=inter) + + return img + + +def to_tensor(img, scale=True, norm=False): + if img.ndim == 2: + img = img[:, :, np.newaxis] + c = img.shape[-1] + + if scale: + img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255) + else: + img_t = torch.from_numpy(img).permute(2, 0, 1).float() + + if norm: + mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) + std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) + img_t = (img_t - mean) / std + return img_t + + +def load_masked_position_encoding(mask): + ones_filter = np.ones((3, 3), dtype=np.float32) + d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32) + d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32) + d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32) + d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32) + str_size = 256 + pos_num = 128 + + ori_mask = mask.copy() + ori_h, ori_w = ori_mask.shape[0:2] + ori_mask = ori_mask / 255 + mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA) + mask[mask > 0] = 255 + h, w = mask.shape[0:2] + mask3 = mask.copy() + mask3 = 1.0 - (mask3 / 255.0) + pos = np.zeros((h, w), dtype=np.int32) + direct = np.zeros((h, w, 4), dtype=np.int32) + i = 0 + while np.sum(1 - mask3) > 0: + i += 1 + mask3_ = cv2.filter2D(mask3, -1, ones_filter) + mask3_[mask3_ > 0] = 1 + sub_mask = mask3_ - mask3 + pos[sub_mask == 1] = i + + m = cv2.filter2D(mask3, -1, d_filter1) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 0] = 1 + + m = cv2.filter2D(mask3, -1, d_filter2) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 1] = 1 + + m = cv2.filter2D(mask3, -1, d_filter3) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 2] = 1 + + m = cv2.filter2D(mask3, -1, d_filter4) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 3] = 1 + + mask3 = mask3_ + + abs_pos = pos.copy() + rel_pos = pos / (str_size / 2) # to 0~1 maybe larger than 1 + rel_pos = (rel_pos * pos_num).astype(np.int32) + rel_pos = np.clip(rel_pos, 0, pos_num - 1) + + if ori_w != w or ori_h != h: + rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) + rel_pos[ori_mask == 0] = 0 + direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) + direct[ori_mask == 0, :] = 0 + + return rel_pos, abs_pos, direct + + +def load_image(img, mask, device, sigma256=3.0): + """ + Args: + img: [H, W, C] RGB + mask: [H, W] 255 为 masks 区域 + sigma256: + + Returns: + + """ + h, w, _ = img.shape + imgh, imgw = img.shape[0:2] + img_256 = resize(img, 256, 256) + + mask = (mask > 127).astype(np.uint8) * 255 + mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA) + mask_256[mask_256 > 0] = 255 + + mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA) + mask_512[mask_512 > 0] = 255 + + # original skimage implemention + # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny + # low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max. + # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max. + + try: + import skimage + + gray_256 = skimage.color.rgb2gray(img_256) + edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float) + # cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8)) + # cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8)) + except: + gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY) + gray_256_blured = cv2.GaussianBlur( + gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256 + ) + edge_256 = cv2.Canny( + gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2) + ) + + # cv2.imwrite("opencv_edge.jpg", edge_256) + + # line + img_512 = resize(img, 512, 512) + + rel_pos, abs_pos, direct = load_masked_position_encoding(mask) + + batch = dict() + batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device) + batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device) + batch["masks"] = to_tensor(mask).unsqueeze(0).to(device) + batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device) + batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device) + batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device) + batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device) + batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device) + batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device) + batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device) + batch["h"] = imgh + batch["w"] = imgw + + return batch + + +def to_device(data, device): + if isinstance(data, torch.Tensor): + return data.to(device) + if isinstance(data, dict): + for key in data: + if isinstance(data[key], torch.Tensor): + data[key] = data[key].to(device) + return data + if isinstance(data, list): + return [to_device(d, device) for d in data] + + +class ZITS(InpaintModel): + name = "zits" + min_size = 256 + pad_mod = 32 + pad_to_square = True + is_erase_model = True + + def __init__(self, device, **kwargs): + """ + + Args: + device: + """ + super().__init__(device) + self.device = device + self.sample_edge_line_iterations = 1 + + def init_model(self, device, **kwargs): + self.wireframe = load_jit_model( + ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5 + ) + self.edge_line = load_jit_model( + ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5 + ) + self.structure_upsample = load_jit_model( + ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 + ) + self.inpaint = load_jit_model( + ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5 + ) + + @staticmethod + def download(): + download_model(ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5) + download_model(ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5) + download_model( + ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 + ) + download_model(ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL), + get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL), + get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL), + get_cache_path_by_url(ZITS_INPAINT_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + def wireframe_edge_and_line(self, items, enable: bool): + # 最终向 items 中添加 edge 和 line key + if not enable: + items["edge"] = torch.zeros_like(items["masks"]) + items["line"] = torch.zeros_like(items["masks"]) + return + + start = time.time() + try: + line_256 = self.wireframe_forward( + items["img_512"], + h=256, + w=256, + masks=items["mask_512"], + mask_th=0.85, + ) + except: + line_256 = torch.zeros_like(items["mask_256"]) + + print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms") + + # np_line = (line[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line.jpg", np_line) + + start = time.time() + edge_pred, line_pred = self.sample_edge_line_logits( + context=[items["img_256"], items["edge_256"], line_256], + mask=items["mask_256"].clone(), + iterations=self.sample_edge_line_iterations, + add_v=0.05, + mul_v=4, + ) + print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms") + + # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("edge_pred.jpg", np_edge_pred) + # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line_pred.jpg", np_line_pred) + # exit() + + input_size = min(items["h"], items["w"]) + if input_size != 256 and input_size > 256: + while edge_pred.shape[2] < input_size: + edge_pred = self.structure_upsample(edge_pred) + edge_pred = torch.sigmoid((edge_pred + 2) * 2) + + line_pred = self.structure_upsample(line_pred) + line_pred = torch.sigmoid((line_pred + 2) * 2) + + edge_pred = F.interpolate( + edge_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + line_pred = F.interpolate( + line_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + + # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred) + # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line_pred_upsample.jpg", np_line_pred) + # exit() + + items["edge"] = edge_pred.detach() + items["line"] = line_pred.detach() + + @torch.no_grad() + def forward(self, image, mask, config: InpaintRequest): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] + return: BGR IMAGE + """ + mask = mask[:, :, 0] + items = load_image(image, mask, device=self.device) + + self.wireframe_edge_and_line(items, config.zits_wireframe) + + inpainted_image = self.inpaint( + items["images"], + items["masks"], + items["edge"], + items["line"], + items["rel_pos"], + items["direct"], + ) + + inpainted_image = inpainted_image * 255.0 + inpainted_image = ( + inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) + ) + inpainted_image = inpainted_image[:, :, ::-1] + + # cv2.imwrite("inpainted.jpg", inpainted_image) + # exit() + + return inpainted_image + + def wireframe_forward(self, images, h, w, masks, mask_th=0.925): + lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1) + lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1) + images = images * 255.0 + # the masks value of lcnn is 127.5 + masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5 + masked_images = (masked_images - lcnn_mean) / lcnn_std + + def to_int(x): + return tuple(map(int, x)) + + lines_tensor = [] + lmap = np.zeros((h, w)) + + output_masked = self.wireframe(masked_images) + + output_masked = to_device(output_masked, "cpu") + if output_masked["num_proposals"] == 0: + lines_masked = [] + scores_masked = [] + else: + lines_masked = output_masked["lines_pred"].numpy() + lines_masked = [ + [line[1] * h, line[0] * w, line[3] * h, line[2] * w] + for line in lines_masked + ] + scores_masked = output_masked["lines_score"].numpy() + + for line, score in zip(lines_masked, scores_masked): + if score > mask_th: + try: + import skimage + + rr, cc, value = skimage.draw.line_aa( + *to_int(line[0:2]), *to_int(line[2:4]) + ) + lmap[rr, cc] = np.maximum(lmap[rr, cc], value) + except: + cv2.line( + lmap, + to_int(line[0:2][::-1]), + to_int(line[2:4][::-1]), + (1, 1, 1), + 1, + cv2.LINE_AA, + ) + + lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8) + lines_tensor.append(to_tensor(lmap).unsqueeze(0)) + + lines_tensor = torch.cat(lines_tensor, dim=0) + return lines_tensor.detach().to(self.device) + + def sample_edge_line_logits( + self, context, mask=None, iterations=1, add_v=0, mul_v=4 + ): + [img, edge, line] = context + + img = img * (1 - mask) + edge = edge * (1 - mask) + line = line * (1 - mask) + + for i in range(iterations): + edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask) + + edge_pred = torch.sigmoid(edge_logits) + line_pred = torch.sigmoid((line_logits + add_v) * mul_v) + edge = edge + edge_pred * mask + edge[edge >= 0.25] = 1 + edge[edge < 0.25] = 0 + line = line + line_pred * mask + + b, _, h, w = edge_pred.shape + edge_pred = edge_pred.reshape(b, -1, 1) + line_pred = line_pred.reshape(b, -1, 1) + mask = mask.reshape(b, -1) + + edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1) + line_probs = torch.cat([1 - line_pred, line_pred], dim=-1) + edge_probs[:, :, 1] += 0.5 + line_probs[:, :, 1] += 0.5 + edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100) + line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100) + + indices = torch.sort( + edge_max_probs + line_max_probs, dim=-1, descending=True + )[1] + + for ii in range(b): + keep = int((i + 1) / iterations * torch.sum(mask[ii, ...])) + + assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!" + mask[ii][indices[ii, :keep]] = 0 + + mask = mask.reshape(b, 1, h, w) + edge = edge * (1 - mask) + line = line * (1 - mask) + + edge, line = edge.to(torch.float32), line.to(torch.float32) + return edge, line diff --git a/custom-demo/back-end/model_manager.py b/custom-demo/back-end/model_manager.py new file mode 100644 index 0000000..15cf32b --- /dev/null +++ b/custom-demo/back-end/model_manager.py @@ -0,0 +1,196 @@ +from typing import List, Dict + +import torch +from loguru import logger +import numpy as np + +from iopaint.download import scan_models +from iopaint.helper import switch_mps_device +from iopaint.model import models, ControlNet, SD, SDXL +from iopaint.model.utils import torch_gc, is_local_files_only +from iopaint.schema import InpaintRequest, ModelInfo, ModelType + + +class ModelManager: + def __init__(self, name: str, device: torch.device, **kwargs): + self.name = name + self.device = device + self.kwargs = kwargs + self.available_models: Dict[str, ModelInfo] = {} + self.scan_models() + + self.enable_controlnet = kwargs.get("enable_controlnet", False) + controlnet_method = kwargs.get("controlnet_method", None) + if ( + controlnet_method is None + and name in self.available_models + and self.available_models[name].support_controlnet + ): + controlnet_method = self.available_models[name].controlnets[0] + self.controlnet_method = controlnet_method + self.model = self.init_model(name, device, **kwargs) + + @property + def current_model(self) -> ModelInfo: + return self.available_models[self.name] + + def init_model(self, name: str, device, **kwargs): + logger.info(f"Loading model: {name}") + if name not in self.available_models: + raise NotImplementedError( + f"Unsupported model: {name}. Available models: {list(self.available_models.keys())}" + ) + + model_info = self.available_models[name] + kwargs = { + **kwargs, + "model_info": model_info, + "enable_controlnet": self.enable_controlnet, + "controlnet_method": self.controlnet_method, + } + + if model_info.support_controlnet and self.enable_controlnet: + return ControlNet(device, **kwargs) + elif model_info.name in models: + return models[name](device, **kwargs) + else: + if model_info.model_type in [ + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SD, + ]: + return SD(device, **kwargs) + + if model_info.model_type in [ + ModelType.DIFFUSERS_SDXL_INPAINT, + ModelType.DIFFUSERS_SDXL, + ]: + return SDXL(device, **kwargs) + + raise NotImplementedError(f"Unsupported model: {name}") + + @torch.inference_mode() + def __call__(self, image, mask, config: InpaintRequest): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + config: + + Returns: + BGR image + """ + self.switch_controlnet_method(config) + self.enable_disable_freeu(config) + self.enable_disable_lcm_lora(config) + return self.model(image, mask, config).astype(np.uint8) + + def scan_models(self) -> List[ModelInfo]: + available_models = scan_models() + self.available_models = {it.name: it for it in available_models} + return available_models + + def switch(self, new_name: str): + if new_name == self.name: + return + + old_name = self.name + old_controlnet_method = self.controlnet_method + self.name = new_name + + if ( + self.available_models[new_name].support_controlnet + and self.controlnet_method + not in self.available_models[new_name].controlnets + ): + self.controlnet_method = self.available_models[new_name].controlnets[0] + try: + # TODO: enable/disable controlnet without reload model + del self.model + torch_gc() + + self.model = self.init_model( + new_name, switch_mps_device(new_name, self.device), **self.kwargs + ) + except Exception as e: + self.name = old_name + self.controlnet_method = old_controlnet_method + logger.info(f"Switch model from {old_name} to {new_name} failed, rollback") + self.model = self.init_model( + old_name, switch_mps_device(old_name, self.device), **self.kwargs + ) + raise e + + def switch_controlnet_method(self, config): + if not self.available_models[self.name].support_controlnet: + return + + if ( + self.enable_controlnet + and config.controlnet_method + and self.controlnet_method != config.controlnet_method + ): + old_controlnet_method = self.controlnet_method + self.controlnet_method = config.controlnet_method + self.model.switch_controlnet_method(config.controlnet_method) + logger.info( + f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}" + ) + elif self.enable_controlnet != config.enable_controlnet: + self.enable_controlnet = config.enable_controlnet + self.controlnet_method = config.controlnet_method + + pipe_components = { + "vae": self.model.model.vae, + "text_encoder": self.model.model.text_encoder, + "unet": self.model.model.unet, + } + if hasattr(self.model.model, "text_encoder_2"): + pipe_components["text_encoder_2"] = self.model.model.text_encoder_2 + + self.model = self.init_model( + self.name, + switch_mps_device(self.name, self.device), + pipe_components=pipe_components, + **self.kwargs, + ) + if not config.enable_controlnet: + logger.info(f"Disable controlnet") + else: + logger.info(f"Enable controlnet: {config.controlnet_method}") + + def enable_disable_freeu(self, config: InpaintRequest): + if str(self.model.device) == "mps": + return + + if self.available_models[self.name].support_freeu: + if config.sd_freeu: + freeu_config = config.sd_freeu_config + self.model.model.enable_freeu( + s1=freeu_config.s1, + s2=freeu_config.s2, + b1=freeu_config.b1, + b2=freeu_config.b2, + ) + else: + self.model.model.disable_freeu() + + def enable_disable_lcm_lora(self, config: InpaintRequest): + if self.available_models[self.name].support_lcm_lora: + # TODO: change this if load other lora is supported + lcm_lora_loaded = bool(self.model.model.get_list_adapters()) + if config.sd_lcm_lora: + if not lcm_lora_loaded: + logger.info("Load LCM LORA") + self.model.model.load_lora_weights( + self.model.lcm_lora_id, + weight_name="pytorch_lora_weights.safetensors", + local_files_only=is_local_files_only(), + ) + else: + logger.info("Enable LCM LORA") + self.model.model.enable_lora() + else: + if lcm_lora_loaded: + logger.info("Disable LCM LORA") + self.model.model.disable_lora() diff --git a/custom-demo/back-end/nohup.out b/custom-demo/back-end/nohup.out new file mode 100644 index 0000000..0c40c25 --- /dev/null +++ b/custom-demo/back-end/nohup.out @@ -0,0 +1,15 @@ +2025-05-08 17:06:14.942 | INFO | iopaint.runtime:setup_model_dir:82 - Model directory: /root/.cache +- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39 +- Python version: 3.12.3 +- torch: 2.7.0 +- torchvision: 0.22.0 +- Pillow: 9.5.0 +- diffusers: 0.27.2 +- transformers: 4.48.3 +- opencv-python: 4.11.0.86 +- accelerate: 1.6.0 +- iopaint: N/A +- rembg: 2.0.65 +- realesrgan: N/A +- gfpgan: N/A + diff --git a/custom-demo/back-end/plugins/__init__.py b/custom-demo/back-end/plugins/__init__.py new file mode 100644 index 0000000..8128025 --- /dev/null +++ b/custom-demo/back-end/plugins/__init__.py @@ -0,0 +1,74 @@ +from typing import Dict + +from loguru import logger + +from .anime_seg import AnimeSeg +from .gfpgan_plugin import GFPGANPlugin +from .interactive_seg import InteractiveSeg +from .realesrgan import RealESRGANUpscaler +from .remove_bg import RemoveBG +from .restoreformer import RestoreFormerPlugin +from ..schema import InteractiveSegModel, Device, RealESRGANModel + + +def build_plugins( + enable_interactive_seg: bool, + interactive_seg_model: InteractiveSegModel, + interactive_seg_device: Device, + enable_remove_bg: bool, + remove_bg_model: str, + enable_anime_seg: bool, + enable_realesrgan: bool, + realesrgan_device: Device, + realesrgan_model: RealESRGANModel, + enable_gfpgan: bool, + gfpgan_device: Device, + enable_restoreformer: bool, + restoreformer_device: Device, + no_half: bool, +) -> Dict: + plugins = {} + if enable_interactive_seg: + logger.info(f"Initialize {InteractiveSeg.name} plugin") + plugins[InteractiveSeg.name] = InteractiveSeg( + interactive_seg_model, interactive_seg_device + ) + + if enable_remove_bg: + logger.info(f"Initialize {RemoveBG.name} plugin") + plugins[RemoveBG.name] = RemoveBG(remove_bg_model) + + if enable_anime_seg: + logger.info(f"Initialize {AnimeSeg.name} plugin") + plugins[AnimeSeg.name] = AnimeSeg() + + if enable_realesrgan: + logger.info( + f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}" + ) + plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( + realesrgan_model, + realesrgan_device, + no_half=no_half, + ) + + if enable_gfpgan: + logger.info(f"Initialize {GFPGANPlugin.name} plugin") + if enable_realesrgan: + logger.info("Use realesrgan as GFPGAN background upscaler") + else: + logger.info( + f"GFPGAN no background upscaler, use --enable-realesrgan to enable it" + ) + plugins[GFPGANPlugin.name] = GFPGANPlugin( + gfpgan_device, + upscaler=plugins.get(RealESRGANUpscaler.name, None), + ) + + if enable_restoreformer: + logger.info(f"Initialize {RestoreFormerPlugin.name} plugin") + plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( + restoreformer_device, + upscaler=plugins.get(RealESRGANUpscaler.name, None), + ) + return plugins diff --git a/custom-demo/back-end/plugins/anime_seg.py b/custom-demo/back-end/plugins/anime_seg.py new file mode 100644 index 0000000..286564b --- /dev/null +++ b/custom-demo/back-end/plugins/anime_seg.py @@ -0,0 +1,462 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from PIL import Image + +from iopaint.helper import load_model +from iopaint.plugins.base_plugin import BasePlugin +from iopaint.schema import RunPluginRequest + + +class REBNCONV(nn.Module): + def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1): + super(REBNCONV, self).__init__() + + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride + ) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self, x): + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src, tar): + src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False) + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7, self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) + hx6dup = _upsample_like(hx6d, hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-5 ### +class RSU5(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4 ### +class RSU4(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4F ### +class RSU4F(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) + hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) + + return hx1d + hxin + + +class ISNetDIS(nn.Module): + def __init__(self, in_ch=3, out_ch=1): + super(ISNetDIS, self).__init__() + + self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1) + self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage1 = RSU7(64, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage6 = RSU4F(512, 256, 512) + + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + + def forward(self, x): + hx = x + + hxin = self.conv_in(hx) + hx = self.pool_in(hxin) + + # stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + # stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + # stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + # stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + # stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + # stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6, hx5) + + # -------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) + + # side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1, x) + return d1.sigmoid() + + +# 从小到大 +ANIME_SEG_MODELS = { + "url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth", + "md5": "5f25479076b73074730ab8de9e8f2051", +} + + +class AnimeSeg(BasePlugin): + # Model from: https://github.com/SkyTNT/anime-segmentation + name = "AnimeSeg" + support_gen_image = True + support_gen_mask = True + + def __init__(self): + super().__init__() + self.model = load_model( + ISNetDIS(), + ANIME_SEG_MODELS["url"], + "cpu", + ANIME_SEG_MODELS["md5"], + ) + + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + mask = self.forward(rgb_np_img) + mask = Image.fromarray(mask, mode="L") + h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1] + empty = Image.new("RGBA", (w0, h0), 0) + img = Image.fromarray(rgb_np_img) + cutout = Image.composite(img, empty, mask) + return np.asarray(cutout) + + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + return self.forward(rgb_np_img) + + @torch.inference_mode() + def forward(self, rgb_np_img): + s = 1024 + + h0, w0 = h, w = rgb_np_img.shape[0], rgb_np_img.shape[1] + if h > w: + h, w = s, int(s * w / h) + else: + h, w = int(s * h / w), s + ph, pw = s - h, s - w + tmpImg = np.zeros([s, s, 3], dtype=np.float32) + tmpImg[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = ( + cv2.resize(rgb_np_img, (w, h)) / 255 + ) + tmpImg = tmpImg.transpose((2, 0, 1)) + tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor) + mask = self.model(tmpImg) + mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] + mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0)) + return (mask * 255).astype("uint8") diff --git a/custom-demo/back-end/plugins/base_plugin.py b/custom-demo/back-end/plugins/base_plugin.py new file mode 100644 index 0000000..1f8bddc --- /dev/null +++ b/custom-demo/back-end/plugins/base_plugin.py @@ -0,0 +1,30 @@ +from loguru import logger +import numpy as np + +from iopaint.schema import RunPluginRequest + + +class BasePlugin: + name: str + support_gen_image: bool = False + support_gen_mask: bool = False + + def __init__(self): + err_msg = self.check_dep() + if err_msg: + logger.error(err_msg) + exit(-1) + + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + # return RGBA np image or BGR np image + ... + + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + # return GRAY or BGR np image, 255 means foreground, 0 means background + ... + + def check_dep(self): + ... + + def switch_model(self, new_model_name: str): + ... diff --git a/custom-demo/back-end/plugins/briarmbg.py b/custom-demo/back-end/plugins/briarmbg.py new file mode 100644 index 0000000..880f530 --- /dev/null +++ b/custom-demo/back-end/plugins/briarmbg.py @@ -0,0 +1,512 @@ +# copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +import numpy as np +from torchvision.transforms.functional import normalize + + +class REBNCONV(nn.Module): + def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1): + super(REBNCONV, self).__init__() + + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride + ) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self, x): + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src, tar): + src = F.interpolate(src, size=tar.shape[2:], mode="bilinear") + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7, self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) + hx6dup = _upsample_like(hx6d, hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-5 ### +class RSU5(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4 ### +class RSU4(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4F ### +class RSU4F(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) + hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__( + self, + in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + ): + super(myrebnconv, self).__init__() + + self.conv = nn.Conv2d( + in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self, x): + return self.rl(self.bn(self.conv(x))) + + +class BriaRMBG(nn.Module): + def __init__(self, in_ch=3, out_ch=1): + super(BriaRMBG, self).__init__() + + self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1) + self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage1 = RSU7(64, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage6 = RSU4F(512, 256, 512) + + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) + self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) + self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) + self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + def forward(self, x): + hx = x + + hxin = self.conv_in(hx) + # hx = self.pool_in(hxin) + + # stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + # stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + # stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + # stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + # stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + # stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6, hx5) + + # -------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) + + # side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1, x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2, x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3, x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4, x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5, x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6, x) + + return [ + F.sigmoid(d1), + F.sigmoid(d2), + F.sigmoid(d3), + F.sigmoid(d4), + F.sigmoid(d5), + F.sigmoid(d6), + ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6] + + +def resize_image(image): + image = image.convert("RGB") + model_input_size = (1024, 1024) + image = image.resize(model_input_size, Image.BILINEAR) + return image + + +def create_briarmbg_session(): + from huggingface_hub import hf_hub_download + + net = BriaRMBG() + model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth") + net.load_state_dict(torch.load(model_path, map_location="cpu")) + net.eval() + return net + + +def briarmbg_process(bgr_np_image, session, only_mask=False): + # prepare input + orig_bgr_image = Image.fromarray(bgr_np_image) + w, h = orig_im_size = orig_bgr_image.size + image = resize_image(orig_bgr_image) + im_np = np.array(image) + im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1) + im_tensor = torch.unsqueeze(im_tensor, 0) + im_tensor = torch.divide(im_tensor, 255.0) + im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) + # inference + result = session(im_tensor) + # post process + result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0) + ma = torch.max(result) + mi = torch.min(result) + result = (result - mi) / (ma - mi) + # image to pil + im_array = (result * 255).cpu().data.numpy().astype(np.uint8) + + mask = np.squeeze(im_array) + if only_mask: + return mask + + pil_im = Image.fromarray(mask) + # paste the mask on the original image + new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0)) + new_im.paste(orig_bgr_image, mask=pil_im) + rgba_np_img = np.asarray(new_im) + return rgba_np_img diff --git a/custom-demo/back-end/plugins/gfpgan_plugin.py b/custom-demo/back-end/plugins/gfpgan_plugin.py new file mode 100644 index 0000000..619280b --- /dev/null +++ b/custom-demo/back-end/plugins/gfpgan_plugin.py @@ -0,0 +1,74 @@ +import cv2 +import numpy as np +from loguru import logger + +from iopaint.helper import download_model +from iopaint.plugins.base_plugin import BasePlugin +from iopaint.schema import RunPluginRequest + + +class GFPGANPlugin(BasePlugin): + name = "GFPGAN" + support_gen_image = True + + def __init__(self, device, upscaler=None): + super().__init__() + from .gfpganer import MyGFPGANer + + url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" + model_md5 = "94d735072630ab734561130a47bc44f8" + model_path = download_model(url, model_md5) + logger.info(f"GFPGAN model path: {model_path}") + + import facexlib + + if hasattr(facexlib.detection.retinaface, "device"): + facexlib.detection.retinaface.device = device + + # Use GFPGAN for face enhancement + self.face_enhancer = MyGFPGANer( + model_path=model_path, + upscale=1, + arch="clean", + channel_multiplier=2, + device=device, + bg_upsampler=upscaler.model if upscaler is not None else None, + ) + self.face_enhancer.face_helper.face_det.mean_tensor.to(device) + self.face_enhancer.face_helper.face_det = ( + self.face_enhancer.face_helper.face_det.to(device) + ) + + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + weight = 0.5 + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + logger.info(f"GFPGAN input shape: {bgr_np_img.shape}") + _, _, bgr_output = self.face_enhancer.enhance( + bgr_np_img, + has_aligned=False, + only_center_face=False, + paste_back=True, + weight=weight, + ) + logger.info(f"GFPGAN output shape: {bgr_output.shape}") + + # try: + # if scale != 2: + # interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4 + # h, w = img.shape[0:2] + # output = cv2.resize( + # output, + # (int(w * scale / 2), int(h * scale / 2)), + # interpolation=interpolation, + # ) + # except Exception as error: + # print("wrong scale input.", error) + return bgr_output + + def check_dep(self): + try: + import gfpgan + except ImportError: + return ( + "gfpgan is not installed, please install it first. pip install gfpgan" + ) diff --git a/custom-demo/back-end/plugins/gfpganer.py b/custom-demo/back-end/plugins/gfpganer.py new file mode 100644 index 0000000..75a575d --- /dev/null +++ b/custom-demo/back-end/plugins/gfpganer.py @@ -0,0 +1,84 @@ +import os + +import torch +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from gfpgan import GFPGANv1Clean, GFPGANer +from torch.hub import get_dir + + +class MyGFPGANer(GFPGANer): + """Helper for restoration with GFPGAN. + + It will detect and crop faces, and then resize the faces to 512x512. + GFPGAN is used to restored the resized faces. + The background is upsampled with the bg_upsampler. + Finally, the faces will be pasted back to the upsample background image. + + Args: + model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). + upscale (float): The upscale of the final output. Default: 2. + arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + bg_upsampler (nn.Module): The upsampler for the background. Default: None. + """ + + def __init__( + self, + model_path, + upscale=2, + arch="clean", + channel_multiplier=2, + bg_upsampler=None, + device=None, + ): + self.upscale = upscale + self.bg_upsampler = bg_upsampler + + # initialize model + self.device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + # initialize the GFP-GAN + if arch == "clean": + self.gfpgan = GFPGANv1Clean( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True, + ) + elif arch == "RestoreFormer": + from gfpgan.archs.restoreformer_arch import RestoreFormer + + self.gfpgan = RestoreFormer() + + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + + # initialize face helper + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + use_parse=True, + device=self.device, + model_rootpath=model_dir, + ) + + loadnet = torch.load(model_path) + if "params_ema" in loadnet: + keyname = "params_ema" + else: + keyname = "params" + self.gfpgan.load_state_dict(loadnet[keyname], strict=True) + self.gfpgan.eval() + self.gfpgan = self.gfpgan.to(self.device) diff --git a/custom-demo/back-end/plugins/interactive_seg.py b/custom-demo/back-end/plugins/interactive_seg.py new file mode 100644 index 0000000..a270991 --- /dev/null +++ b/custom-demo/back-end/plugins/interactive_seg.py @@ -0,0 +1,107 @@ +import hashlib +from typing import List + +import numpy as np +import torch +from loguru import logger + +from iopaint.helper import download_model +from iopaint.plugins.base_plugin import BasePlugin +from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry +from iopaint.plugins.segment_anything.predictor_hq import SamHQPredictor +from iopaint.schema import RunPluginRequest + +# 从小到大 +SEGMENT_ANYTHING_MODELS = { + "vit_b": { + "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + "md5": "01ec64d29a2fca3f0661936605ae66f8", + }, + "vit_l": { + "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "md5": "0b3195507c641ddb6910d2bb5adee89c", + }, + "vit_h": { + "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "md5": "4b8939a88964f0f4ff5f5b2642c598a6", + }, + "mobile_sam": { + "url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt", + "md5": "f3c0d8cda613564d499310dab6c812cd", + }, + "sam_hq_vit_b": { + "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth", + "md5": "c6b8953247bcfdc8bb8ef91e36a6cacc", + }, + "sam_hq_vit_l": { + "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth", + "md5": "08947267966e4264fb39523eccc33f86", + }, + "sam_hq_vit_h": { + "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", + "md5": "3560f6b6a5a6edacd814a1325c39640a", + }, +} + + +class InteractiveSeg(BasePlugin): + name = "InteractiveSeg" + support_gen_mask = True + + def __init__(self, model_name, device): + super().__init__() + self.model_name = model_name + self.device = device + self._init_session(model_name) + + def _init_session(self, model_name: str): + model_path = download_model( + SEGMENT_ANYTHING_MODELS[model_name]["url"], + SEGMENT_ANYTHING_MODELS[model_name]["md5"], + ) + logger.info(f"SegmentAnything model path: {model_path}") + if "sam_hq" in model_name: + self.predictor = SamHQPredictor( + sam_model_registry[model_name](checkpoint=model_path).to(self.device) + ) + else: + self.predictor = SamPredictor( + sam_model_registry[model_name](checkpoint=model_path).to(self.device) + ) + self.prev_img_md5 = None + + def switch_model(self, new_model_name): + if self.model_name == new_model_name: + return + + logger.info( + f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}" + ) + self._init_session(new_model_name) + self.model_name = new_model_name + + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest() + return self.forward(rgb_np_img, req.clicks, img_md5) + + @torch.inference_mode() + def forward(self, rgb_np_img, clicks: List[List], img_md5: str): + input_point = [] + input_label = [] + for click in clicks: + x = click[0] + y = click[1] + input_point.append([x, y]) + input_label.append(click[2]) + + if img_md5 and img_md5 != self.prev_img_md5: + self.prev_img_md5 = img_md5 + self.predictor.set_image(rgb_np_img) + + masks, scores, _ = self.predictor.predict( + point_coords=np.array(input_point), + point_labels=np.array(input_label), + multimask_output=False, + ) + mask = masks[0].astype(np.uint8) * 255 + return mask diff --git a/custom-demo/back-end/plugins/realesrgan.py b/custom-demo/back-end/plugins/realesrgan.py new file mode 100644 index 0000000..5275700 --- /dev/null +++ b/custom-demo/back-end/plugins/realesrgan.py @@ -0,0 +1,109 @@ +import cv2 +import numpy as np +import torch +from loguru import logger + +from iopaint.helper import download_model +from iopaint.plugins.base_plugin import BasePlugin +from iopaint.schema import RunPluginRequest, RealESRGANModel + + +class RealESRGANUpscaler(BasePlugin): + name = "RealESRGAN" + support_gen_image = True + + def __init__(self, name, device, no_half=False): + super().__init__() + self.model_name = name + self.device = device + self.no_half = no_half + self._init_model(name) + + def _init_model(self, name): + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + from realesrgan.archs.srvgg_arch import SRVGGNetCompact + + REAL_ESRGAN_MODELS = { + RealESRGANModel.realesr_general_x4v3: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + "scale": 4, + "model": lambda: SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=4, + act_type="prelu", + ), + "model_md5": "91a7644643c884ee00737db24e478156", + }, + RealESRGANModel.RealESRGAN_x4plus: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + "scale": 4, + "model": lambda: RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ), + "model_md5": "99ec365d4afad750833258a1a24f44ca", + }, + RealESRGANModel.RealESRGAN_x4plus_anime_6B: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + "scale": 4, + "model": lambda: RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=6, + num_grow_ch=32, + scale=4, + ), + "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba", + }, + } + if name not in REAL_ESRGAN_MODELS: + raise ValueError(f"Unknown RealESRGAN model name: {name}") + model_info = REAL_ESRGAN_MODELS[name] + + model_path = download_model(model_info["url"], model_info["model_md5"]) + logger.info(f"RealESRGAN model path: {model_path}") + + self.model = RealESRGANer( + scale=model_info["scale"], + model_path=model_path, + model=model_info["model"](), + half=True if "cuda" in str(self.device) and not self.no_half else False, + tile=512, + tile_pad=10, + pre_pad=10, + device=self.device, + ) + + def switch_model(self, new_model_name: str): + if self.model_name == new_model_name: + return + self._init_model(new_model_name) + self.model_name = new_model_name + + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}") + result = self.forward(bgr_np_img, req.scale) + logger.info(f"RealESRGAN output shape: {result.shape}") + return result + + @torch.inference_mode() + def forward(self, bgr_np_img, scale: float): + # 输出是 BGR + upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0] + return upsampled + + def check_dep(self): + try: + import realesrgan + except ImportError: + return "RealESRGAN is not installed, please install it first. pip install realesrgan" diff --git a/custom-demo/back-end/plugins/remove_bg.py b/custom-demo/back-end/plugins/remove_bg.py new file mode 100644 index 0000000..64bf785 --- /dev/null +++ b/custom-demo/back-end/plugins/remove_bg.py @@ -0,0 +1,71 @@ +import os +import cv2 +import numpy as np +from loguru import logger +from torch.hub import get_dir + +from iopaint.plugins.base_plugin import BasePlugin +from iopaint.schema import RunPluginRequest, RemoveBGModel + + +class RemoveBG(BasePlugin): + name = "RemoveBG" + support_gen_mask = True + support_gen_image = True + + def __init__(self, model_name): + super().__init__() + self.model_name = model_name + + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + os.environ["U2NET_HOME"] = model_dir + + self._init_session(model_name) + + def _init_session(self, model_name: str): + if model_name == RemoveBGModel.briaai_rmbg_1_4: + from iopaint.plugins.briarmbg import ( + create_briarmbg_session, + briarmbg_process, + ) + + self.session = create_briarmbg_session() + self.remove = briarmbg_process + else: + from rembg import new_session, remove + + self.session = new_session(model_name=model_name) + self.remove = remove + + def switch_model(self, new_model_name): + if self.model_name == new_model_name: + return + + logger.info( + f"Switching removebg model from {self.model_name} to {new_model_name}" + ) + self._init_session(new_model_name) + self.model_name = new_model_name + + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + + # return BGRA image + output = self.remove(bgr_np_img, session=self.session) + return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) + + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + + # return BGR image, 255 means foreground, 0 means background + output = self.remove(bgr_np_img, session=self.session, only_mask=True) + return output + + def check_dep(self): + try: + import rembg + except ImportError: + return ( + "RemoveBG is not installed, please install it first. pip install rembg" + ) diff --git a/custom-demo/back-end/plugins/restoreformer.py b/custom-demo/back-end/plugins/restoreformer.py new file mode 100644 index 0000000..4e1d3e7 --- /dev/null +++ b/custom-demo/back-end/plugins/restoreformer.py @@ -0,0 +1,57 @@ +import cv2 +import numpy as np +from loguru import logger + +from iopaint.helper import download_model +from iopaint.plugins.base_plugin import BasePlugin +from iopaint.schema import RunPluginRequest + + +class RestoreFormerPlugin(BasePlugin): + name = "RestoreFormer" + support_gen_image = True + + def __init__(self, device, upscaler=None): + super().__init__() + from .gfpganer import MyGFPGANer + + url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth" + model_md5 = "eaeeff6c4a1caa1673977cb374e6f699" + model_path = download_model(url, model_md5) + logger.info(f"RestoreFormer model path: {model_path}") + + import facexlib + + if hasattr(facexlib.detection.retinaface, "device"): + facexlib.detection.retinaface.device = device + + self.face_enhancer = MyGFPGANer( + model_path=model_path, + upscale=1, + arch="RestoreFormer", + channel_multiplier=2, + device=device, + bg_upsampler=upscaler.model if upscaler is not None else None, + ) + + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: + weight = 0.5 + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}") + _, _, bgr_output = self.face_enhancer.enhance( + bgr_np_img, + has_aligned=False, + only_center_face=False, + paste_back=True, + weight=weight, + ) + logger.info(f"RestoreFormer output shape: {bgr_output.shape}") + return bgr_output + + def check_dep(self): + try: + import gfpgan + except ImportError: + return ( + "gfpgan is not installed, please install it first. pip install gfpgan" + ) diff --git a/custom-demo/back-end/plugins/segment_anything/__init__.py b/custom-demo/back-end/plugins/segment_anything/__init__.py new file mode 100644 index 0000000..420f04b --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + build_sam_vit_h_hq, + build_sam_vit_l_hq, + build_sam_vit_b_hq, + sam_model_registry, +) +from .predictor import SamPredictor diff --git a/custom-demo/back-end/plugins/segment_anything/build_sam.py b/custom-demo/back-end/plugins/segment_anything/build_sam.py new file mode 100644 index 0000000..9b905ef --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/build_sam.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from iopaint.plugins.segment_anything.modeling.tiny_vit_sam import TinyViT + +from .modeling import ( + ImageEncoderViT, + MaskDecoder, + PromptEncoder, + Sam, + TwoWayTransformer, +) +from .modeling.image_encoder_hq import ImageEncoderViTHQ +from .modeling.mask_decoder import MaskDecoderHQ +from .modeling.sam_hq import SamHQ + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT( + img_size=1024, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + mobile_sam.load_state_dict(state_dict) + return mobile_sam + + +def build_sam_vit_h_hq(checkpoint=None): + return _build_sam_hq( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +def build_sam_vit_l_hq(checkpoint=None): + return _build_sam_hq( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b_hq(checkpoint=None): + return _build_sam_hq( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, + "sam_hq_vit_h": build_sam_vit_h_hq, + "sam_hq_vit_l": build_sam_vit_l_hq, + "sam_hq_vit_b": build_sam_vit_b_hq, + "mobile_sam": build_sam_vit_t, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam + + +def _build_sam_hq( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = SamHQ( + image_encoder=ImageEncoderViTHQ( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoderHQ( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + vit_dim=encoder_embed_dim, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + device = "cuda" if torch.cuda.is_available() else "cpu" + state_dict = torch.load(f, map_location=device) + info = sam.load_state_dict(state_dict, strict=False) + print(info) + for n, p in sam.named_parameters(): + if ( + "hf_token" not in n + and "hf_mlp" not in n + and "compress_vit_feat" not in n + and "embedding_encoder" not in n + and "embedding_maskfeature" not in n + ): + p.requires_grad = False + + return sam diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/__init__.py b/custom-demo/back-end/plugins/segment_anything/modeling/__init__.py new file mode 100644 index 0000000..38e9062 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/common.py b/custom-demo/back-end/plugins/segment_anything/modeling/common.py new file mode 100644 index 0000000..2bf1523 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/image_encoder.py b/custom-demo/back-end/plugins/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000..a6ad9ad --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/image_encoder_hq.py b/custom-demo/back-end/plugins/segment_anything/modeling/image_encoder_hq.py new file mode 100644 index 0000000..f12803b --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/image_encoder_hq.py @@ -0,0 +1,422 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViTHQ(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + interm_embeddings = [] + for blk in self.blocks: + x = blk(x) + if blk.window_size == 0: + interm_embeddings.append(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x, interm_embeddings + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos( + attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + attn = attn.softmax(dim=-1) + x = ( + (attn @ v) + .view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/mask_decoder.py b/custom-demo/back-end/plugins/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000..67e0f77 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,410 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for outptu + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + +# https://github.com/SysCV/sam-hq/blob/main/segment_anything/modeling/mask_decoder_hq.py#L17 +class MaskDecoderHQ(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + vit_dim: int = 1024, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + # HQ-SAM parameters + self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Ouptput-Token + self.hf_mlp = MLP( + transformer_dim, transformer_dim, transformer_dim // 8, 3 + ) # corresponding new MLP layer for HQ-Ouptput-Token + self.num_mask_tokens = self.num_mask_tokens + 1 + + # three conv fusion layers for obtaining HQ-Feature + self.compress_vit_feat = nn.Sequential( + nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim), + nn.GELU(), + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 8, kernel_size=2, stride=2 + ), + ) + + self.embedding_encoder = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + ) + self.embedding_maskfeature = nn.Sequential( + nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1), + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + interm_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the ViT image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + vit_features = interm_embeddings[0].permute( + 0, 3, 1, 2 + ) # early-layer ViT feature, after 1st global attention block in ViT + hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat( + vit_features + ) + + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + hq_features=hq_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + # mask with highest score + mask_slice = slice(1, self.num_mask_tokens - 1) + iou_pred = iou_pred[:, mask_slice] + iou_pred, max_iou_idx = torch.max(iou_pred, dim=1) + iou_pred = iou_pred.unsqueeze(1) + masks_multi = masks[:, mask_slice, :, :] + masks_sam = masks_multi[ + torch.arange(masks_multi.size(0)), max_iou_idx + ].unsqueeze(1) + else: + # singale mask output, default + mask_slice = slice(0, 1) + iou_pred = iou_pred[:, mask_slice] + masks_sam = masks[:, mask_slice] + + masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens)] + if hq_token_only: + masks = masks_hq + else: + masks = masks_sam + masks_hq + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + hq_features: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], + dim=0, + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + upscaled_embedding_sam = self.output_upscaling(src) + upscaled_embedding_hq = self.embedding_maskfeature( + upscaled_embedding_sam + ) + hq_features.repeat(b, 1, 1, 1) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + if i < self.num_mask_tokens - 1: + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + else: + hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :])) + + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding_sam.shape + + masks_sam = ( + hyper_in[:, : self.num_mask_tokens - 1] + @ upscaled_embedding_sam.view(b, c, h * w) + ).view(b, -1, h, w) + masks_sam_hq = ( + hyper_in[:, self.num_mask_tokens - 1 :] + @ upscaled_embedding_hq.view(b, c, h * w) + ).view(b, -1, h, w) + masks = torch.cat([masks_sam, masks_sam_hq], dim=1) + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/prompt_encoder.py b/custom-demo/back-end/plugins/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000..c3143f4 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/sam.py b/custom-demo/back-end/plugins/segment_anything/modeling/sam.py new file mode 100644 index 0000000..303bc2f --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input promts, + C is determiend by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/sam_hq.py b/custom-demo/back-end/plugins/segment_anything/modeling/sam_hq.py new file mode 100644 index 0000000..d2ae3a3 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/sam_hq.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class SamHQ(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + hq_token_only: bool =False, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings, interm_embeddings = self.image_encoder(input_images) + interm_embeddings = interm_embeddings[0] # early layer + + outputs = [] + for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0), + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x \ No newline at end of file diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/tiny_vit_sam.py b/custom-demo/back-end/plugins/segment_anything/modeling/tiny_vit_sam.py new file mode 100644 index 0000000..a5127c7 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/tiny_vit_sam.py @@ -0,0 +1,822 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import collections +import itertools +import math +import warnings +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from typing import Tuple + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(itertools.repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class TimmDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(TimmDropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class Conv2d_BN(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1 + ): + super().__init__() + self.add_module( + "c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False) + ) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module("bn", bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1) * self.c.groups, + w.size(0), + w.shape[2:], + stride=self.c.stride, + padding=self.c.padding, + dilation=self.c.dilation, + groups=self.c.groups, + ) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f"(drop_prob={self.drop_prob})" + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN( + self.hidden_chans, + self.hidden_chans, + ks=3, + stride=1, + pad=1, + groups=self.hidden_chans, + ) + self.act2 = activation() + + self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c = 2 + if out_dim == 320 or out_dim == 448 or out_dim == 576: + stride_c = 1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__( + self, + dim, + input_resolution, + depth, + activation, + drop_path=0.0, + downsample=None, + use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4.0, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + MBConv( + dim, + dim, + conv_expand_ratio, + activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation + ) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets)) + ) + self.register_buffer( + "attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False + ) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, "ab"): + del self.ab + else: + self.register_buffer( + "ab", + self.attention_biases[:, self.attention_bias_idxs], + persistent=False, + ) + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, -1).split( + [self.key_dim, self.key_dim, self.d], dim=3 + ) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + ( + self.attention_biases[:, self.attention_bias_idxs] + if self.training + else self.ab + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r"""TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, "window_size must be greater than 0" + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + assert dim % num_heads == 0, "dim must be divisible by num_heads" + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention( + dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution + ) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, + drop=drop, + ) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim + ) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_r = (self.window_size - W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = ( + x.view(B, nH, self.window_size, nW, self.window_size, C) + .transpose(2, 3) + .reshape(B * nH * nW, self.window_size * self.window_size, C) + ) + x = self.attn(x) + # window reverse + x = ( + x.view(B, nH, nW, self.window_size, self.window_size, C) + .transpose(2, 3) + .reshape(B, pH, pW, C) + ) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class BasicLayer(nn.Module): + """A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. Default: dim + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + downsample=None, + use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + TinyViTBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation + ) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class TinyViT(nn.Module): + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dims=[96, 192, 384, 768], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size = img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed( + in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation, + ) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict( + dim=embed_dims[i_layer], + input_resolution=( + patches_resolution[0] + // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + patches_resolution[1] + // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + ), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs, + ) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = ( + nn.Linear(embed_dims[-1], num_classes) + if num_classes > 0 + else torch.nn.Identity() + ) + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + # print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, "lr_scale"), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"attention_biases"} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + B, _, C = x.size() + x = x.view(B, 64, 64, C) + x = x.permute(0, 3, 1, 2) + x = self.neck(x) + return x + + def forward(self, x): + x = self.forward_features(x) + # x = self.norm_head(x) + # x = self.head(x) + return x diff --git a/custom-demo/back-end/plugins/segment_anything/modeling/transformer.py b/custom-demo/back-end/plugins/segment_anything/modeling/transformer.py new file mode 100644 index 0000000..f1a2812 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attenion layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/custom-demo/back-end/plugins/segment_anything/predictor.py b/custom-demo/back-end/plugins/segment_anything/predictor.py new file mode 100644 index 0000000..23d0649 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/predictor.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + from .utils.transforms import ResizeLongestSide + + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[ + None, :, :, : + ] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + labels_torch = torch.as_tensor( + point_labels, dtype=torch.int, device=self.device + ) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor( + mask_input, dtype=torch.float, device=self.device + ) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks( + low_res_masks, self.input_size, self.original_size + ) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self.features is not None + ), "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/custom-demo/back-end/plugins/segment_anything/predictor_hq.py b/custom-demo/back-end/plugins/segment_anything/predictor_hq.py new file mode 100644 index 0000000..d8fd50f --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/predictor_hq.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamHQPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + # import pdb;pdb.set_trace() + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + # import pdb;pdb.set_trace() + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[ + None, :, :, : + ] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features, self.interm_features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + hq_token_only: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + labels_torch = torch.as_tensor( + point_labels, dtype=torch.int, device=self.device + ) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor( + mask_input, dtype=torch.float, device=self.device + ) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + hq_token_only=hq_token_only, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + hq_token_only: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=self.interm_features, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks( + low_res_masks, self.input_size, self.original_size + ) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self.features is not None + ), "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/custom-demo/back-end/plugins/segment_anything/utils/__init__.py b/custom-demo/back-end/plugins/segment_anything/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/custom-demo/back-end/plugins/segment_anything/utils/transforms.py b/custom-demo/back-end/plugins/segment_anything/utils/transforms.py new file mode 100644 index 0000000..90f50ed --- /dev/null +++ b/custom-demo/back-end/plugins/segment_anything/utils/transforms.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords( + self, coords: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes( + self, boxes: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape( + oldh: int, oldw: int, long_side_length: int + ) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/custom-demo/back-end/requirements.txt b/custom-demo/back-end/requirements.txt new file mode 100644 index 0000000..9f27a10 --- /dev/null +++ b/custom-demo/back-end/requirements.txt @@ -0,0 +1,25 @@ +torch>=2.0.0 +opencv-python +diffusers==0.27.2 +huggingface_hub==0.25.2 +accelerate +peft==0.7.1 +transformers>=4.39.1 +safetensors +controlnet-aux==0.0.3 +fastapi==0.108.0 +uvicorn +python-multipart +python-socketio==5.7.2 +typer +pydantic>=2.5.2 +rich +loguru +yacs +piexif==1.1.3 +omegaconf +easydict +gradio==4.21.0 +typer-config==1.4.0 + +Pillow==9.5.0 # for AnyText diff --git a/custom-demo/back-end/runtime.py b/custom-demo/back-end/runtime.py new file mode 100644 index 0000000..7199f83 --- /dev/null +++ b/custom-demo/back-end/runtime.py @@ -0,0 +1,88 @@ +# https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py +import os +import platform +import sys +from pathlib import Path + +import packaging.version +from iopaint.schema import Device +from loguru import logger +from rich import print +from typing import Dict, Any + + +_PY_VERSION: str = sys.version.split()[0].rstrip("+") + +if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"): + import importlib_metadata # type: ignore +else: + import importlib.metadata as importlib_metadata # type: ignore + +_package_versions = {} + +_CANDIDATES = [ + "torch", + "torchvision", + "Pillow", + "diffusers", + "transformers", + "opencv-python", + "accelerate", + "iopaint", + "rembg", + "realesrgan", + "gfpgan", +] +# Check once at runtime +for name in _CANDIDATES: + _package_versions[name] = "N/A" + try: + _package_versions[name] = importlib_metadata.version(name) + except importlib_metadata.PackageNotFoundError: + pass + + +def dump_environment_info() -> Dict[str, str]: + """Dump information about the machine to help debugging issues.""" + + # Generic machine info + info: Dict[str, Any] = { + "Platform": platform.platform(), + "Python version": platform.python_version(), + } + info.update(_package_versions) + print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n") + return info + + +def check_device(device: Device) -> Device: + if device == Device.cuda: + import platform + + if platform.system() == "Darwin": + logger.warning("MacOS does not support cuda, use cpu instead") + return Device.cpu + else: + import torch + + if not torch.cuda.is_available(): + logger.warning("CUDA is not available, use cpu instead") + return Device.cpu + elif device == Device.mps: + import torch + + if not torch.backends.mps.is_available(): + logger.warning("mps is not available, use cpu instead") + return Device.cpu + return device + + +def setup_model_dir(model_dir: Path): + model_dir = model_dir.expanduser().absolute() + logger.info(f"Model directory: {model_dir}") + os.environ["U2NET_HOME"] = str(model_dir) + os.environ["XDG_CACHE_HOME"] = str(model_dir) + if not model_dir.exists(): + logger.info(f"Create model directory: {model_dir}") + model_dir.mkdir(exist_ok=True, parents=True) + return model_dir diff --git a/custom-demo/back-end/schema.py b/custom-demo/back-end/schema.py new file mode 100644 index 0000000..c8ba9ca --- /dev/null +++ b/custom-demo/back-end/schema.py @@ -0,0 +1,470 @@ +import random +from enum import Enum +from pathlib import Path +from typing import Optional, Literal, List + +from iopaint.const import ( + INSTRUCT_PIX2PIX_NAME, + KANDINSKY22_NAME, + POWERPAINT_NAME, + ANYTEXT_NAME, + SDXL_CONTROLNET_CHOICES, + SD2_CONTROLNET_CHOICES, + SD_CONTROLNET_CHOICES, +) +from loguru import logger +from pydantic import BaseModel, Field, field_validator, computed_field + + +class ModelType(str, Enum): + INPAINT = "inpaint" # LaMa, MAT... + DIFFUSERS_SD = "diffusers_sd" + DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint" + DIFFUSERS_SDXL = "diffusers_sdxl" + DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint" + DIFFUSERS_OTHER = "diffusers_other" + + +class ModelInfo(BaseModel): + name: str + path: str + model_type: ModelType + is_single_file_diffusers: bool = False + + @computed_field + @property + def need_prompt(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [ + INSTRUCT_PIX2PIX_NAME, + KANDINSKY22_NAME, + POWERPAINT_NAME, + ANYTEXT_NAME, + ] + + @computed_field + @property + def controlnets(self) -> List[str]: + if self.model_type in [ + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SDXL_INPAINT, + ]: + return SDXL_CONTROLNET_CHOICES + if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]: + if "sd2" in self.name.lower(): + return SD2_CONTROLNET_CHOICES + else: + return SD_CONTROLNET_CHOICES + if self.name == POWERPAINT_NAME: + return SD_CONTROLNET_CHOICES + return [] + + @computed_field + @property + def support_strength(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME] + + @computed_field + @property + def support_outpainting(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME] + + @computed_field + @property + def support_lcm_lora(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + @computed_field + @property + def support_controlnet(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + @computed_field + @property + def support_freeu(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [INSTRUCT_PIX2PIX_NAME] + + +class Choices(str, Enum): + @classmethod + def values(cls): + return [member.value for member in cls] + + +class RealESRGANModel(Choices): + realesr_general_x4v3 = "realesr-general-x4v3" + RealESRGAN_x4plus = "RealESRGAN_x4plus" + RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B" + + +class RemoveBGModel(Choices): + u2net = "u2net" + u2netp = "u2netp" + u2net_human_seg = "u2net_human_seg" + u2net_cloth_seg = "u2net_cloth_seg" + silueta = "silueta" + isnet_general_use = "isnet-general-use" + briaai_rmbg_1_4 = "briaai/RMBG-1.4" + + +class Device(Choices): + cpu = "cpu" + cuda = "cuda" + mps = "mps" + + +class InteractiveSegModel(Choices): + vit_b = "vit_b" + vit_l = "vit_l" + vit_h = "vit_h" + sam_hq_vit_b = "sam_hq_vit_b" + sam_hq_vit_l = "sam_hq_vit_l" + sam_hq_vit_h = "sam_hq_vit_h" + mobile_sam = "mobile_sam" + + +class PluginInfo(BaseModel): + name: str + support_gen_image: bool = False + support_gen_mask: bool = False + + +class CV2Flag(str, Enum): + INPAINT_NS = "INPAINT_NS" + INPAINT_TELEA = "INPAINT_TELEA" + + +class HDStrategy(str, Enum): + # Use original image size + ORIGINAL = "Original" + # Resize the longer side of the image to a specific size(hd_strategy_resize_limit), + # then do inpainting on the resized image. Finally, resize the inpainting result to the original size. + # The area outside the mask will not lose quality. + RESIZE = "Resize" + # Crop masking area(with a margin controlled by hd_strategy_crop_margin) from the original image to do inpainting + CROP = "Crop" + + +class LDMSampler(str, Enum): + ddim = "ddim" + plms = "plms" + + +class SDSampler(str, Enum): + dpm_plus_plus_2m = "DPM++ 2M" + dpm_plus_plus_2m_karras = "DPM++ 2M Karras" + dpm_plus_plus_2m_sde = "DPM++ 2M SDE" + dpm_plus_plus_2m_sde_karras = "DPM++ 2M SDE Karras" + dpm_plus_plus_sde = "DPM++ SDE" + dpm_plus_plus_sde_karras = "DPM++ SDE Karras" + dpm2 = "DPM2" + dpm2_karras = "DPM2 Karras" + dpm2_a = "DPM2 a" + dpm2_a_karras = "DPM2 a Karras" + euler = "Euler" + euler_a = "Euler a" + heun = "Heun" + lms = "LMS" + lms_karras = "LMS Karras" + + ddim = "DDIM" + pndm = "PNDM" + uni_pc = "UniPC" + lcm = "LCM" + + +class FREEUConfig(BaseModel): + s1: float = 0.9 + s2: float = 0.2 + b1: float = 1.2 + b2: float = 1.4 + + +class PowerPaintTask(str, Enum): + text_guided = "text-guided" + shape_guided = "shape-guided" + object_remove = "object-remove" + outpainting = "outpainting" + + +class ApiConfig(BaseModel): + host: str + port: int + inbrowser: bool + model: str + no_half: bool + low_mem: bool + cpu_offload: bool + disable_nsfw_checker: bool + local_files_only: bool + cpu_textencoder: bool + device: Device + input: Optional[Path] + output_dir: Optional[Path] + quality: int + enable_interactive_seg: bool + interactive_seg_model: InteractiveSegModel + interactive_seg_device: Device + enable_remove_bg: bool + remove_bg_model: str + enable_anime_seg: bool + enable_realesrgan: bool + realesrgan_device: Device + realesrgan_model: RealESRGANModel + enable_gfpgan: bool + gfpgan_device: Device + enable_restoreformer: bool + restoreformer_device: Device + + +class InpaintRequest(BaseModel): + image: Optional[str] = Field(None, description="base64 encoded image") + mask: Optional[str] = Field(None, description="base64 encoded mask") + + ldm_steps: int = Field(20, description="Steps for ldm model.") + ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.") + zits_wireframe: bool = Field(True, description="Enable wireframe for zits model.") + + hd_strategy: str = Field( + HDStrategy.CROP, + description="Different way to preprocess image, only used by erase models(e.g. lama/mat)", + ) + hd_strategy_crop_trigger_size: int = Field( + 800, + description="Crop trigger size for hd_strategy=CROP, if the longer side of the image is larger than this value, use crop strategy", + ) + hd_strategy_crop_margin: int = Field( + 128, description="Crop margin for hd_strategy=CROP" + ) + hd_strategy_resize_limit: int = Field( + 1280, description="Resize limit for hd_strategy=RESIZE" + ) + + prompt: str = Field("", description="Prompt for diffusion models.") + negative_prompt: str = Field( + "", description="Negative prompt for diffusion models." + ) + use_croper: bool = Field( + False, description="Crop image before doing diffusion inpainting" + ) + croper_x: int = Field(0, description="Crop x for croper") + croper_y: int = Field(0, description="Crop y for croper") + croper_height: int = Field(512, description="Crop height for croper") + croper_width: int = Field(512, description="Crop width for croper") + + use_extender: bool = Field( + False, description="Extend image before doing sd outpainting" + ) + extender_x: int = Field(0, description="Extend x for extender") + extender_y: int = Field(0, description="Extend y for extender") + extender_height: int = Field(640, description="Extend height for extender") + extender_width: int = Field(640, description="Extend width for extender") + + sd_scale: float = Field( + 1.0, + description="Resize the image before doing sd inpainting, the area outside the mask will not lose quality.", + gt=0.0, + le=1.0, + ) + sd_mask_blur: int = Field( + 11, + description="Blur the edge of mask area. The higher the number the smoother blend with the original image", + ) + sd_strength: float = Field( + 1.0, + description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image", + le=1.0, + ) + sd_steps: int = Field( + 50, + description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.", + ) + sd_guidance_scale: float = Field( + 7.5, + help="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.", + ) + sd_sampler: str = Field( + SDSampler.uni_pc, description="Sampler for diffusion model." + ) + sd_seed: int = Field( + 42, + description="Seed for diffusion model. -1 mean random seed", + validate_default=True, + ) + sd_match_histograms: bool = Field( + False, + description="Match histograms between inpainting area and original image.", + ) + + sd_outpainting_softness: float = Field(20.0) + sd_outpainting_space: float = Field(20.0) + + sd_freeu: bool = Field( + False, + description="Enable freeu mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu", + ) + sd_freeu_config: FREEUConfig = FREEUConfig() + + sd_lcm_lora: bool = Field( + False, + description="Enable lcm-lora mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm#texttoimage", + ) + + sd_keep_unmasked_area: bool = Field( + True, description="Keep unmasked area unchanged" + ) + + cv2_flag: CV2Flag = Field( + CV2Flag.INPAINT_NS, + description="Flag for opencv inpainting: https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07", + ) + cv2_radius: int = Field( + 4, + description="Radius of a circular neighborhood of each point inpainted that is considered by the algorithm", + ) + + # Paint by Example + paint_by_example_example_image: Optional[str] = Field( + None, description="Base64 encoded example image for paint by example model" + ) + + # InstructPix2Pix + p2p_image_guidance_scale: float = Field(1.5, description="Image guidance scale") + + # ControlNet + enable_controlnet: bool = Field(False, description="Enable controlnet") + controlnet_conditioning_scale: float = Field( + 0.4, description="Conditioning scale", ge=0.0, le=1.0 + ) + controlnet_method: str = Field( + "lllyasviel/control_v11p_sd15_canny", description="Controlnet method" + ) + + # PowerPaint + powerpaint_task: PowerPaintTask = Field( + PowerPaintTask.text_guided, description="PowerPaint task" + ) + fitting_degree: float = Field( + 1.0, + description="Control the fitting degree of the generated objects to the mask shape.", + gt=0.0, + le=1.0, + ) + + @field_validator("sd_seed") + @classmethod + def sd_seed_validator(cls, v: int) -> int: + if v == -1: + return random.randint(1, 99999999) + return v + + @field_validator("controlnet_conditioning_scale") + @classmethod + def validate_field(cls, v: float, values): + use_extender = values.data["use_extender"] + enable_controlnet = values.data["enable_controlnet"] + if use_extender and enable_controlnet: + logger.info(f"Extender is enabled, set controlnet_conditioning_scale=0") + return 0 + return v + + @field_validator("sd_strength") + @classmethod + def validate_sd_strength(cls, v: float, values): + use_extender = values.data["use_extender"] + if use_extender: + logger.info(f"Extender is enabled, set sd_strength=1") + return 1.0 + return v + + +class RunPluginRequest(BaseModel): + name: str + image: str = Field(..., description="base64 encoded image") + clicks: List[List[int]] = Field( + [], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]" + ) + scale: float = Field(2.0, description="Scale for upscaling") + + +MediaTab = Literal["input", "output"] + + +class MediasResponse(BaseModel): + name: str + height: int + width: int + ctime: float + mtime: float + + +class GenInfoResponse(BaseModel): + prompt: str = "" + negative_prompt: str = "" + + +class ServerConfigResponse(BaseModel): + plugins: List[PluginInfo] + modelInfos: List[ModelInfo] + removeBGModel: RemoveBGModel + removeBGModels: List[RemoveBGModel] + realesrganModel: RealESRGANModel + realesrganModels: List[RealESRGANModel] + interactiveSegModel: InteractiveSegModel + interactiveSegModels: List[InteractiveSegModel] + enableFileManager: bool + enableAutoSaving: bool + enableControlnet: bool + controlnetMethod: Optional[str] + disableModelSwitch: bool + isDesktop: bool + samplers: List[str] + + +class SwitchModelRequest(BaseModel): + name: str + + +class SwitchPluginModelRequest(BaseModel): + plugin_name: str + model_name: str + + +AdjustMaskOperate = Literal["expand", "shrink", "reverse"] + + +class AdjustMaskRequest(BaseModel): + mask: str = Field( + ..., description="base64 encoded mask. 255 means area to do inpaint" + ) + operate: AdjustMaskOperate = Field(..., description="expand/shrink/reverse") + kernel_size: int = Field(5, description="Kernel size for expanding mask") diff --git a/custom-demo/back-end/tests/.gitignore b/custom-demo/back-end/tests/.gitignore new file mode 100644 index 0000000..89b7717 --- /dev/null +++ b/custom-demo/back-end/tests/.gitignore @@ -0,0 +1,2 @@ +*_result.png +result/ \ No newline at end of file diff --git a/custom-demo/back-end/tests/__init__.py b/custom-demo/back-end/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/tests/anime_test.png b/custom-demo/back-end/tests/anime_test.png new file mode 100644 index 0000000..6b86838 Binary files /dev/null and b/custom-demo/back-end/tests/anime_test.png differ diff --git a/custom-demo/back-end/tests/anytext_mask.jpg b/custom-demo/back-end/tests/anytext_mask.jpg new file mode 100644 index 0000000..43d8b12 Binary files /dev/null and b/custom-demo/back-end/tests/anytext_mask.jpg differ diff --git a/custom-demo/back-end/tests/anytext_ref.jpg b/custom-demo/back-end/tests/anytext_ref.jpg new file mode 100644 index 0000000..c36b3c5 Binary files /dev/null and b/custom-demo/back-end/tests/anytext_ref.jpg differ diff --git a/custom-demo/back-end/tests/bunny.jpeg b/custom-demo/back-end/tests/bunny.jpeg new file mode 100644 index 0000000..3727a45 Binary files /dev/null and b/custom-demo/back-end/tests/bunny.jpeg differ diff --git a/custom-demo/back-end/tests/cat.png b/custom-demo/back-end/tests/cat.png new file mode 100644 index 0000000..dee9eb6 Binary files /dev/null and b/custom-demo/back-end/tests/cat.png differ diff --git a/custom-demo/back-end/tests/icc_profile_test.jpg b/custom-demo/back-end/tests/icc_profile_test.jpg new file mode 100644 index 0000000..b603ef9 Binary files /dev/null and b/custom-demo/back-end/tests/icc_profile_test.jpg differ diff --git a/custom-demo/back-end/tests/icc_profile_test.png b/custom-demo/back-end/tests/icc_profile_test.png new file mode 100644 index 0000000..90d18ac Binary files /dev/null and b/custom-demo/back-end/tests/icc_profile_test.png differ diff --git a/custom-demo/back-end/tests/image.png b/custom-demo/back-end/tests/image.png new file mode 100644 index 0000000..74c7a7b Binary files /dev/null and b/custom-demo/back-end/tests/image.png differ diff --git a/custom-demo/back-end/tests/mask.png b/custom-demo/back-end/tests/mask.png new file mode 100644 index 0000000..29cf20b Binary files /dev/null and b/custom-demo/back-end/tests/mask.png differ diff --git a/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo.png b/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo.png new file mode 100644 index 0000000..e84dfc8 Binary files /dev/null and b/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo.png differ diff --git a/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_all_mask.png b/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_all_mask.png new file mode 100644 index 0000000..e69de29 diff --git a/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_mask.png b/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_mask.png new file mode 100644 index 0000000..7f3c753 Binary files /dev/null and b/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_mask.png differ diff --git a/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_mask_blur.png b/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_mask_blur.png new file mode 100644 index 0000000..a630379 Binary files /dev/null and b/custom-demo/back-end/tests/overture-creations-5sI6fQgYIuo_mask_blur.png differ diff --git a/custom-demo/back-end/tests/png_parameter_test.png b/custom-demo/back-end/tests/png_parameter_test.png new file mode 100644 index 0000000..dc18bce Binary files /dev/null and b/custom-demo/back-end/tests/png_parameter_test.png differ diff --git a/custom-demo/back-end/tests/test_adjust_mask.py b/custom-demo/back-end/tests/test_adjust_mask.py new file mode 100644 index 0000000..1f01713 --- /dev/null +++ b/custom-demo/back-end/tests/test_adjust_mask.py @@ -0,0 +1,17 @@ +import cv2 +from iopaint.helper import adjust_mask +from iopaint.tests.utils import current_dir, save_dir + +mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png" + + +def test_adjust_mask(): + mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) + res_mask = adjust_mask(mask, 0, "expand") + cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask) + res_mask = adjust_mask(mask, 40, "expand") + cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask) + res_mask = adjust_mask(mask, 20, "shrink") + cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask) + res_mask = adjust_mask(mask, 20, "reverse") + cv2.imwrite(str(save_dir / "adjust_mask_reverse.png"), res_mask) diff --git a/custom-demo/back-end/tests/test_anytext.py b/custom-demo/back-end/tests/test_anytext.py new file mode 100644 index 0000000..996176f --- /dev/null +++ b/custom-demo/back-end/tests/test_anytext.py @@ -0,0 +1,45 @@ +import os + +from iopaint.tests.utils import check_device, get_config, assert_equal + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +def test_anytext(device): + sd_steps = check_device(device) + model = ModelManager( + name="Sanster/AnyText", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt='Characters written in chalk on the blackboard that says "DADDY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks', + negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture", + sd_steps=sd_steps, + sd_guidance_scale=9.0, + sd_seed=66273235, + sd_match_histograms=True + ) + + assert_equal( + model, + cfg, + f"anytext.png", + img_p=current_dir / "anytext_ref.jpg", + mask_p=current_dir / "anytext_mask.jpg", + ) diff --git a/custom-demo/back-end/tests/test_controlnet.py b/custom-demo/back-end/tests/test_controlnet.py new file mode 100644 index 0000000..c271345 --- /dev/null +++ b/custom-demo/back-end/tests/test_controlnet.py @@ -0,0 +1,118 @@ +import os + +from iopaint.const import SD_CONTROLNET_CHOICES +from iopaint.tests.utils import current_dir, check_device, get_config, assert_equal + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy, SDSampler + + +model_name = "runwayml/stable-diffusion-inpainting" + + +def convert_controlnet_method_name(name): + return name.replace("/", "--") + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("controlnet_method", [SD_CONTROLNET_CHOICES[0]]) +def test_runway_sd_1_5(device, controlnet_method): + sd_steps = check_device(device) + + model = ModelManager( + name=model_name, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=device == "cuda", + enable_controlnet=True, + controlnet_method=controlnet_method, + ) + + cfg = get_config( + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + enable_controlnet=True, + controlnet_conditioning_scale=0.5, + controlnet_method=controlnet_method, + ) + name = f"device_{device}" + + assert_equal( + model, + cfg, + f"sd_controlnet_{convert_controlnet_method_name(controlnet_method)}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +def test_controlnet_switch(device): + sd_steps = check_device(device) + model = ModelManager( + name=model_name, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + enable_controlnet=True, + controlnet_method="lllyasviel/control_v11p_sd15_canny", + ) + cfg = get_config( + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + enable_controlnet=True, + controlnet_method="lllyasviel/control_v11f1p_sd15_depth", + ) + + assert_equal( + model, + cfg, + f"controlnet_switch_canny_to_depth_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.2 + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize( + "local_file", ["sd-v1-5-inpainting.ckpt", "v1-5-pruned-emaonly.safetensors"] +) +def test_local_file_path(device, local_file): + sd_steps = check_device(device) + + controlnet_kwargs = dict( + enable_controlnet=True, + controlnet_method=SD_CONTROLNET_CHOICES[0], + ) + + model = ModelManager( + name=local_file, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + **controlnet_kwargs, + ) + cfg = get_config( + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + **controlnet_kwargs, + ) + + name = f"device_{device}" + + assert_equal( + model, + cfg, + f"{convert_controlnet_method_name(controlnet_kwargs['controlnet_method'])}_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/custom-demo/back-end/tests/test_instruct_pix2pix.py b/custom-demo/back-end/tests/test_instruct_pix2pix.py new file mode 100644 index 0000000..f1ab4e2 --- /dev/null +++ b/custom-demo/back-end/tests/test_instruct_pix2pix.py @@ -0,0 +1,40 @@ +from pathlib import Path + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy +from iopaint.tests.utils import get_config, check_device, assert_equal, current_dir + +model_name = "timbrooks/instruct-pix2pix" + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("disable_nsfw", [True, False]) +@pytest.mark.parametrize("cpu_offload", [False, True]) +def test_instruct_pix2pix(device, disable_nsfw, cpu_offload): + sd_steps = check_device(device) + model = ModelManager( + name=model_name, + device=torch.device(device), + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=False, + cpu_offload=cpu_offload, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="What if it were snowing?", + sd_steps=sd_steps + ) + + name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}" + + assert_equal( + model, + cfg, + f"instruct_pix2pix_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3, + ) diff --git a/custom-demo/back-end/tests/test_load_img.py b/custom-demo/back-end/tests/test_load_img.py new file mode 100644 index 0000000..f7071bf --- /dev/null +++ b/custom-demo/back-end/tests/test_load_img.py @@ -0,0 +1,19 @@ +from iopaint.helper import load_img +from iopaint.tests.utils import current_dir + +png_img_p = current_dir / "image.png" +jpg_img_p = current_dir / "bunny.jpeg" + + +def test_load_png_image(): + with open(png_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (256, 256, 3) + assert alpha_channel.shape == (256, 256) + + +def test_load_jpg_image(): + with open(jpg_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (394, 448, 3) + assert alpha_channel is None diff --git a/custom-demo/back-end/tests/test_low_mem.py b/custom-demo/back-end/tests/test_low_mem.py new file mode 100644 index 0000000..70e8801 --- /dev/null +++ b/custom-demo/back-end/tests/test_low_mem.py @@ -0,0 +1,131 @@ +import os + +from loguru import logger + +from iopaint.tests.utils import check_device, get_config, assert_equal, current_dir + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy, SDSampler, FREEUConfig + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +def test_runway_sd_1_5_low_mem(device): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + low_mem=True, + ) + + all_samplers = [member.value for member in SDSampler.__members__.values()] + print(all_samplers) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + sd_sampler=SDSampler.ddim, + ) + + name = f"device_{device}" + + assert_equal( + model, + cfg, + f"runway_sd_{name}_low_mem.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("sampler", [SDSampler.lcm]) +def test_runway_sd_lcm_lora_low_mem(device, sampler): + check_device(device) + + sd_steps = 5 + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + low_mem=True, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=2, + sd_lcm_lora=True, + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_sd_1_5_lcm_lora_device_{device}_low_mem.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_sd_freeu(device, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + low_mem=True, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=7.5, + sd_freeu=True, + sd_freeu_config=FREEUConfig(), + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_sd_1_5_freeu_device_{device}_low_mem.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_norm_sd_model(device, strategy, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-v1-5", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + low_mem=True, + ) + cfg = get_config( + strategy=strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_{device}_norm_sd_model_device_{device}_low_mem.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/custom-demo/back-end/tests/test_match_histograms.py b/custom-demo/back-end/tests/test_match_histograms.py new file mode 100644 index 0000000..c20a283 --- /dev/null +++ b/custom-demo/back-end/tests/test_match_histograms.py @@ -0,0 +1,36 @@ +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import SDSampler, HDStrategy +from iopaint.tests.utils import check_device, get_config, assert_equal, current_dir + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_sd_match_histograms(device, sampler): + sd_steps = check_device(device) + + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=7.5, + sd_lcm_lora=False, + sd_match_histograms=True, + sd_sampler=sampler + ) + + assert_equal( + model, + cfg, + f"runway_sd_1_5_device_{device}_match_histograms.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/custom-demo/back-end/tests/test_model.py b/custom-demo/back-end/tests/test_model.py new file mode 100644 index 0000000..dd84b12 --- /dev/null +++ b/custom-demo/back-end/tests/test_model.py @@ -0,0 +1,160 @@ +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy, LDMSampler +from iopaint.tests.utils import assert_equal, get_config, current_dir, check_device + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +def test_lama(device, strategy): + check_device(device) + model = ModelManager(name="lama", device=device) + assert_equal( + model, + get_config(strategy=strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + get_config(strategy=strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png", + fx=1.3, + ) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms]) +def test_ldm(device, strategy, ldm_sampler): + check_device(device) + model = ModelManager(name="ldm", device=device) + cfg = get_config(strategy=strategy, ldm_sampler=ldm_sampler) + assert_equal( + model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png" + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_fx_{fx}_result.png", + fx=fx, + ) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("zits_wireframe", [False, True]) +def test_zits(device, strategy, zits_wireframe): + check_device(device) + model = ModelManager(name="zits", device=device) + cfg = get_config(strategy=strategy, zits_wireframe=zits_wireframe) + assert_equal( + model, + cfg, + f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"zits_{strategy.capitalize()}_wireframe_{zits_wireframe}_fx_{fx}_result.png", + fx=fx, + ) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("no_half", [True, False]) +def test_mat(device, strategy, no_half): + check_device(device) + model = ModelManager(name="mat", device=device, no_half=no_half) + cfg = get_config(strategy=strategy) + + assert_equal( + model, + cfg, + f"mat_{strategy.capitalize()}_result.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_fcf(device, strategy): + check_device(device) + model = ModelManager(name="fcf", device=device) + cfg = get_config(strategy=strategy) + + assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=2, fy=2) + assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=3.8, fy=2) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("cv2_flag", ["INPAINT_NS", "INPAINT_TELEA"]) +@pytest.mark.parametrize("cv2_radius", [3, 15]) +def test_cv2(strategy, cv2_flag, cv2_radius): + model = ModelManager( + name="cv2", + device=torch.device("cpu"), + ) + cfg = get_config(strategy=strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius) + assert_equal( + model, + cfg, + f"cv2_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +def test_manga(device, strategy): + check_device(device) + model = ModelManager( + name="manga", + device=torch.device(device), + ) + cfg = get_config(strategy=strategy) + assert_equal( + model, + cfg, + f"manga_{strategy.capitalize()}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_mi_gan(device, strategy): + check_device(device) + model = ModelManager( + name="migan", + device=torch.device(device), + ) + cfg = get_config(strategy=strategy) + assert_equal( + model, + cfg, + f"migan_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.5, + fy=1.7 + ) diff --git a/custom-demo/back-end/tests/test_model_md5.py b/custom-demo/back-end/tests/test_model_md5.py new file mode 100644 index 0000000..3a81d72 --- /dev/null +++ b/custom-demo/back-end/tests/test_model_md5.py @@ -0,0 +1,16 @@ +def test_load_model(): + from iopaint.plugins import InteractiveSeg + from iopaint.model_manager import ModelManager + + interactive_seg_model = InteractiveSeg("vit_l", "cpu") + + models = ["lama", "ldm", "zits", "mat", "fcf", "manga", "migan"] + for m in models: + ModelManager( + name=m, + device="cpu", + no_half=False, + disable_nsfw=False, + sd_cpu_textencoder=True, + cpu_offload=True, + ) diff --git a/custom-demo/back-end/tests/test_model_switch.py b/custom-demo/back-end/tests/test_model_switch.py new file mode 100644 index 0000000..735e1bd --- /dev/null +++ b/custom-demo/back-end/tests/test_model_switch.py @@ -0,0 +1,70 @@ +import os + +from iopaint.schema import InpaintRequest + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +import torch + +from iopaint.model_manager import ModelManager + + +def test_model_switch(): + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + enable_controlnet=True, + controlnet_method="lllyasviel/control_v11p_sd15_canny", + device=torch.device("mps"), + disable_nsfw=True, + sd_cpu_textencoder=True, + cpu_offload=False, + ) + + model.switch("lama") + + +def test_controlnet_switch_onoff(caplog): + name = "runwayml/stable-diffusion-inpainting" + model = ModelManager( + name=name, + enable_controlnet=True, + controlnet_method="lllyasviel/control_v11p_sd15_canny", + device=torch.device("mps"), + disable_nsfw=True, + sd_cpu_textencoder=True, + cpu_offload=False, + ) + + model.switch_controlnet_method( + InpaintRequest( + name=name, + enable_controlnet=False, + ) + ) + + assert "Disable controlnet" in caplog.text + + +def test_switch_controlnet_method(caplog): + name = "runwayml/stable-diffusion-inpainting" + old_method = "lllyasviel/control_v11p_sd15_canny" + new_method = "lllyasviel/control_v11p_sd15_openpose" + model = ModelManager( + name=name, + enable_controlnet=True, + controlnet_method=old_method, + device=torch.device("mps"), + disable_nsfw=True, + sd_cpu_textencoder=True, + cpu_offload=False, + ) + + model.switch_controlnet_method( + InpaintRequest( + name=name, + enable_controlnet=True, + controlnet_method=new_method, + ) + ) + + assert f"Switch Controlnet method from {old_method} to {new_method}" in caplog.text diff --git a/custom-demo/back-end/tests/test_outpainting.py b/custom-demo/back-end/tests/test_outpainting.py new file mode 100644 index 0000000..024d701 --- /dev/null +++ b/custom-demo/back-end/tests/test_outpainting.py @@ -0,0 +1,138 @@ +import os + +from iopaint.tests.utils import current_dir, check_device + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy, SDSampler +from iopaint.tests.test_model import get_config, assert_equal + + +@pytest.mark.parametrize("name", ["runwayml/stable-diffusion-inpainting"]) +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize( + "rect", + [ + [0, -100, 512, 512 - 128 + 100], + [0, 128, 512, 512 - 128 + 100], + [128, 0, 512 - 128 + 100, 512], + [-100, 0, 512 - 128 + 100, 512], + [0, 0, 512, 512 + 200], + [0, 0, 512 + 200, 512], + [-100, -100, 512 + 200, 512 + 200], + ], +) +def test_outpainting(name, device, rect): + sd_steps = check_device(device) + + model = ModelManager( + name=name, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + prompt="a dog sitting on a bench in the park", + sd_steps=sd_steps, + use_extender=True, + extender_x=rect[0], + extender_y=rect[1], + extender_width=rect[2], + extender_height=rect[3], + sd_guidance_scale=8.0, + sd_sampler=SDSampler.dpm_plus_plus_2m, + ) + + assert_equal( + model, + cfg, + f"{name.replace('/', '--')}_outpainting_{'_'.join(map(str, rect))}_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("name", ["kandinsky-community/kandinsky-2-2-decoder-inpaint"]) +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize( + "rect", + [ + [-128, -128, 768, 768], + ], +) +def test_kandinsky_outpainting(name, device, rect): + sd_steps = check_device(device) + + model = ModelManager( + name=name, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + prompt="a cat", + negative_prompt="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature", + sd_steps=sd_steps, + use_extender=True, + extender_x=rect[0], + extender_y=rect[1], + extender_width=rect[2], + extender_height=rect[3], + sd_guidance_scale=7, + sd_sampler=SDSampler.dpm_plus_plus_2m, + ) + + assert_equal( + model, + cfg, + f"{name.replace('/', '--')}_outpainting_{'_'.join(map(str, rect))}_device_{device}.png", + img_p=current_dir / "cat.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1, + fy=1, + ) + + +@pytest.mark.parametrize("name", ["Sanster/PowerPaint-V1-stable-diffusion-inpainting"]) +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize( + "rect", + [ + [-100, -100, 512 + 200, 512 + 200], + ], +) +def test_powerpaint_outpainting(name, device, rect): + sd_steps = check_device(device) + + model = ModelManager( + name=name, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + low_mem=True + ) + cfg = get_config( + prompt="a dog sitting on a bench in the park", + sd_steps=sd_steps, + use_extender=True, + extender_x=rect[0], + extender_y=rect[1], + extender_width=rect[2], + extender_height=rect[3], + sd_guidance_scale=8.0, + sd_sampler=SDSampler.dpm_plus_plus_2m, + powerpaint_task="outpainting", + ) + + assert_equal( + model, + cfg, + f"{name.replace('/', '--')}_outpainting_{'_'.join(map(str, rect))}_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/custom-demo/back-end/tests/test_paint_by_example.py b/custom-demo/back-end/tests/test_paint_by_example.py new file mode 100644 index 0000000..27b8a77 --- /dev/null +++ b/custom-demo/back-end/tests/test_paint_by_example.py @@ -0,0 +1,55 @@ +import cv2 +import pytest +from PIL import Image + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy +from iopaint.tests.utils import ( + current_dir, + get_config, + get_data, + save_dir, + check_device, +) + +model_name = "Fantasy-Studio/Paint-by-Example" + + +def assert_equal( + model, + config, + save_name: str, + fx: float = 1, + fy: float = 1, + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + example_p=current_dir / "bunny.jpeg", +): + img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) + + example_image = cv2.imread(str(example_p)) + example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB) + example_image = cv2.resize( + example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA + ) + + print(f"Input image shape: {img.shape}, example_image: {example_image.shape}") + config.paint_by_example_example_image = Image.fromarray(example_image) + res = model(img, mask, config) + cv2.imwrite(str(save_dir / save_name), res) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +def test_paint_by_example(device): + sd_steps = check_device(device) + model = ModelManager(name=model_name, device=device, disable_nsfw=True) + cfg = get_config(strategy=HDStrategy.ORIGINAL, sd_steps=sd_steps) + assert_equal( + model, + cfg, + f"paint_by_example_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fy=0.9, + fx=1.3, + ) diff --git a/custom-demo/back-end/tests/test_plugins.py b/custom-demo/back-end/tests/test_plugins.py new file mode 100644 index 0000000..aa7d367 --- /dev/null +++ b/custom-demo/back-end/tests/test_plugins.py @@ -0,0 +1,121 @@ +import hashlib +import os +import time +from PIL import Image + +from iopaint.helper import encode_pil_to_base64, gen_frontend_mask +from iopaint.plugins.anime_seg import AnimeSeg +from iopaint.schema import RunPluginRequest, RemoveBGModel, InteractiveSegModel +from iopaint.tests.utils import check_device, current_dir, save_dir + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +import cv2 +import pytest + +from iopaint.plugins import ( + RemoveBG, + RealESRGANUpscaler, + GFPGANPlugin, + RestoreFormerPlugin, + InteractiveSeg, +) + +img_p = current_dir / "bunny.jpeg" +img_bytes = open(img_p, "rb").read() +bgr_img = cv2.imread(str(img_p)) +rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) +rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {}) +bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {}) + + +def _save(img, name): + cv2.imwrite(str(save_dir / name), img) + + +def test_remove_bg(): + model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4) + rgba_np_img = model.gen_image( + rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) + ) + res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA) + _save(res, "test_remove_bg.png") + + bgr_np_img = model.gen_mask( + rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) + ) + + res_mask = gen_frontend_mask(bgr_np_img) + _save(res_mask, "test_remove_bg_frontend_mask.png") + + assert len(bgr_np_img.shape) == 2 + _save(bgr_np_img, "test_remove_bg_mask.jpeg") + + +def test_anime_seg(): + model = AnimeSeg() + img = cv2.imread(str(current_dir / "anime_test.png")) + img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {}) + res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) + assert len(res.shape) == 3 + assert res.shape[-1] == 4 + _save(res, "test_anime_seg.png") + + res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) + assert len(res.shape) == 2 + _save(res, "test_anime_seg_mask.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_upscale(device): + check_device(device) + model = RealESRGANUpscaler("realesr-general-x4v3", device) + res = model.gen_image( + rgb_img, + RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2), + ) + _save(res, f"test_upscale_x2_{device}.png") + + res = model.gen_image( + rgb_img, + RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4), + ) + _save(res, f"test_upscale_x4_{device}.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_gfpgan(device): + check_device(device) + model = GFPGANPlugin(device) + res = model.gen_image( + rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64) + ) + _save(res, f"test_gfpgan_{device}.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_restoreformer(device): + check_device(device) + model = RestoreFormerPlugin(device) + res = model.gen_image( + rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64) + ) + _save(res, f"test_restoreformer_{device}.png") + + +@pytest.mark.parametrize("name", InteractiveSegModel.values()) +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_segment_anything(name, device): + check_device(device) + model = InteractiveSeg(name, device) + new_mask = model.gen_mask( + rgb_img, + RunPluginRequest( + name=InteractiveSeg.name, + image=rgb_img_base64, + clicks=([[448 // 2, 394 // 2, 1]]), + ), + ) + + save_name = f"test_segment_anything_{name}_{device}.png" + _save(new_mask, save_name) diff --git a/custom-demo/back-end/tests/test_save_exif.py b/custom-demo/back-end/tests/test_save_exif.py new file mode 100644 index 0000000..5c19810 --- /dev/null +++ b/custom-demo/back-end/tests/test_save_exif.py @@ -0,0 +1,59 @@ +import io +import tempfile +from pathlib import Path +from typing import List + +from PIL import Image + +from iopaint.helper import pil_to_bytes, load_img + +current_dir = Path(__file__).parent.absolute().resolve() + + +def print_exif(exif): + for k, v in exif.items(): + print(f"{k}: {v}") + + +def extra_info(img_p: Path): + ext = img_p.suffix.strip(".") + img_bytes = img_p.read_bytes() + np_img, _, infos = load_img(img_bytes, False, True) + res_pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos) + res_img = Image.open(io.BytesIO(res_pil_bytes)) + return infos, res_img.info, res_pil_bytes + + +def assert_keys(keys: List[str], infos, res_infos): + for k in keys: + assert k in infos + assert k in res_infos + assert infos[k] == res_infos[k] + + +def run_test(file_path, keys): + infos, res_infos, res_pil_bytes = extra_info(file_path) + assert_keys(keys, infos, res_infos) + with tempfile.NamedTemporaryFile("wb", suffix=file_path.suffix) as temp_file: + temp_file.write(res_pil_bytes) + temp_file.flush() + infos, res_infos, res_pil_bytes = extra_info(Path(temp_file.name)) + assert_keys(keys, infos, res_infos) + + +def test_png_icc_profile_png(): + run_test(current_dir / "icc_profile_test.png", ["icc_profile", "exif"]) + + +def test_png_icc_profile_jpeg(): + run_test(current_dir / "icc_profile_test.jpg", ["icc_profile", "exif"]) + + +def test_jpeg(): + jpg_img_p = current_dir / "bunny.jpeg" + run_test(jpg_img_p, ["dpi", "exif"]) + + +def test_png_parameter(): + jpg_img_p = current_dir / "png_parameter_test.png" + run_test(jpg_img_p, ["parameters"]) diff --git a/custom-demo/back-end/tests/test_sd_model.py b/custom-demo/back-end/tests/test_sd_model.py new file mode 100644 index 0000000..aa26c71 --- /dev/null +++ b/custom-demo/back-end/tests/test_sd_model.py @@ -0,0 +1,269 @@ +import os + +from loguru import logger + +from iopaint.tests.utils import check_device, get_config, assert_equal + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy, SDSampler, FREEUConfig + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +def test_runway_sd_1_5_all_samplers(device): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + + all_samplers = [member.value for member in SDSampler.__members__.values()] + print(all_samplers) + for sampler in all_samplers: + print(f"Testing sampler {sampler}") + if ( + sampler + in [SDSampler.dpm2_karras, SDSampler.dpm2_a_karras, SDSampler.lms_karras] + and device == "mps" + ): + # diffusers 0.25.0 still has bug on these sampler on mps, wait main branch released to fix it + logger.warning( + "skip dpm2_karras on mps, diffusers does not support it on mps. TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead." + ) + continue + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + sd_sampler=sampler, + ) + + name = f"device_{device}_{sampler}" + + assert_equal( + model, + cfg, + f"runway_sd_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("sampler", [SDSampler.lcm]) +def test_runway_sd_lcm_lora(device, sampler): + check_device(device) + + sd_steps = 5 + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=2, + sd_lcm_lora=True, + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_sd_1_5_lcm_lora_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_sd_freeu(device, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=7.5, + sd_freeu=True, + sd_freeu_config=FREEUConfig(), + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_sd_1_5_freeu_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_sd_sd_strength(device, strategy, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=strategy, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + sd_strength=0.8, + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_sd_strength_0.8_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_sd_cpu_textencoder(device, strategy, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=True, + ) + cfg = get_config( + strategy=strategy, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + sd_sampler=sampler, + ) + + assert_equal( + model, + cfg, + f"runway_sd_device_{device}_cpu_textencoder.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_norm_sd_model(device, strategy, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-v1-5", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_{device}_norm_sd_model_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m]) +def test_runway_sd_1_5_cpu_offload(device, strategy, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + ) + cfg = get_config( + strategy=strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps + ) + cfg.sd_sampler = sampler + + name = f"device_{device}_{sampler}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +@pytest.mark.parametrize( + "name", + [ + "sd-v1-5-inpainting.safetensors", + "v1-5-pruned-emaonly.safetensors", + "sd_xl_base_1.0.safetensors", + "sd_xl_base_1.0_inpainting_0.1.safetensors", + ], +) +def test_local_file_path(device, sampler, name): + sd_steps = check_device(device) + model = ModelManager( + name=name, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=False, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + ) + cfg.sd_sampler = sampler + + name = f"device_{device}_{sampler}_{name}" + + is_sdxl = "sd_xl" in name + + assert_equal( + model, + cfg, + f"sd_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.5 if is_sdxl else 1, + fy=1.5 if is_sdxl else 1, + ) diff --git a/custom-demo/back-end/tests/test_sdxl.py b/custom-demo/back-end/tests/test_sdxl.py new file mode 100644 index 0000000..e236948 --- /dev/null +++ b/custom-demo/back-end/tests/test_sdxl.py @@ -0,0 +1,172 @@ +import os + +from iopaint.tests.utils import check_device, current_dir + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy, SDSampler, FREEUConfig +from iopaint.tests.test_model import get_config, assert_equal + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_sdxl(device, strategy, sampler): + sd_steps = check_device(device) + + model = ModelManager( + name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=strategy, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_strength=1.0, + sd_guidance_scale=7.0, + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"sdxl_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=2, + fy=2, + ) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_sdxl_cpu_text_encoder(device, strategy, sampler): + sd_steps = check_device(device) + + model = ModelManager( + name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=True, + ) + cfg = get_config( + strategy=strategy, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_strength=1.0, + sd_guidance_scale=7.0, + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"sdxl_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=2, + fy=2, + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler): + sd_steps = check_device(device) + + model = ModelManager( + name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=strategy, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_strength=1.0, + sd_guidance_scale=2.0, + sd_lcm_lora=True, + ) + cfg.sd_sampler = sampler + + name = f"device_{device}_{sampler}" + + assert_equal( + model, + cfg, + f"sdxl_{name}_lcm_lora.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=2, + fy=2, + ) + + cfg = get_config( + strategy=strategy, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=7.5, + sd_freeu=True, + sd_freeu_config=FREEUConfig(), + ) + + assert_equal( + model, + cfg, + f"sdxl_{name}_freeu_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=2, + fy=2, + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +@pytest.mark.parametrize( + "rect", + [ + [-128, -128, 1024, 1024], + ], +) +def test_sdxl_outpainting(device, rect): + sd_steps = check_device(device) + + model = ModelManager( + name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="a dog sitting on a bench in the park", + sd_steps=sd_steps, + use_extender=True, + extender_x=rect[0], + extender_y=rect[1], + extender_width=rect[2], + extender_height=rect[3], + sd_strength=1.0, + sd_guidance_scale=8.0, + sd_sampler=SDSampler.ddim, + ) + + assert_equal( + model, + cfg, + f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.5, + fy=1.5, + ) diff --git a/custom-demo/back-end/tests/utils.py b/custom-demo/back-end/tests/utils.py new file mode 100644 index 0000000..08f4aeb --- /dev/null +++ b/custom-demo/back-end/tests/utils.py @@ -0,0 +1,77 @@ +from pathlib import Path +import cv2 +import pytest +import torch + +from iopaint.helper import encode_pil_to_base64 +from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler +from PIL import Image + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) + + +def check_device(device: str) -> int: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skip test on cuda") + if device == "mps" and not torch.backends.mps.is_available(): + pytest.skip("mps is not available, skip test on mps") + steps = 2 if device == "cpu" else 20 + return steps + + +def assert_equal( + model, + config: InpaintRequest, + gt_name, + fx: float = 1, + fy: float = 1, + img_p=current_dir / "image.png", + mask_p=current_dir / "mask.png", +): + img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) + print(f"Input image shape: {img.shape}") + res = model(img, mask, config) + ok = cv2.imwrite( + str(save_dir / gt_name), + res, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + ) + assert ok, save_dir / gt_name + + """ + Note that JPEG is lossy compression, so even if it is the highest quality 100, + when the saved images is reloaded, a difference occurs with the original pixel value. + If you want to save the original images as it is, save it as PNG or BMP. + """ + # gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED) + # assert np.array_equal(res, gt) + + +def get_data( + fx: float = 1, + fy: float = 1.0, + img_p=current_dir / "image.png", + mask_p=current_dir / "mask.png", +): + img = cv2.imread(str(img_p)) + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) + img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) + mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST) + return img, mask + + +def get_config(**kwargs): + data = dict( + sd_sampler=kwargs.get("sd_sampler", SDSampler.uni_pc), + ldm_steps=1, + ldm_sampler=LDMSampler.plms, + hd_strategy=kwargs.get("strategy", HDStrategy.ORIGINAL), + hd_strategy_crop_margin=32, + hd_strategy_crop_trigger_size=200, + hd_strategy_resize_limit=200, + ) + data.update(**kwargs) + return InpaintRequest(image="", mask="", **data) diff --git a/custom-demo/back-end/web_app/.eslintrc.cjs b/custom-demo/back-end/web_app/.eslintrc.cjs new file mode 100644 index 0000000..d6c9537 --- /dev/null +++ b/custom-demo/back-end/web_app/.eslintrc.cjs @@ -0,0 +1,18 @@ +module.exports = { + root: true, + env: { browser: true, es2020: true }, + extends: [ + 'eslint:recommended', + 'plugin:@typescript-eslint/recommended', + 'plugin:react-hooks/recommended', + ], + ignorePatterns: ['dist', '.eslintrc.cjs'], + parser: '@typescript-eslint/parser', + plugins: ['react-refresh'], + rules: { + 'react-refresh/only-export-components': [ + 'warn', + { allowConstantExport: true }, + ], + }, +} diff --git a/custom-demo/back-end/web_app/.gitignore b/custom-demo/back-end/web_app/.gitignore new file mode 100644 index 0000000..a547bf3 --- /dev/null +++ b/custom-demo/back-end/web_app/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/custom-demo/back-end/web_app/README.md b/custom-demo/back-end/web_app/README.md new file mode 100644 index 0000000..0d6babe --- /dev/null +++ b/custom-demo/back-end/web_app/README.md @@ -0,0 +1,30 @@ +# React + TypeScript + Vite + +This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules. + +Currently, two official plugins are available: + +- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh +- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh + +## Expanding the ESLint configuration + +If you are developing a production application, we recommend updating the configuration to enable type aware lint rules: + +- Configure the top-level `parserOptions` property like this: + +```js +export default { + // other rules... + parserOptions: { + ecmaVersion: 'latest', + sourceType: 'module', + project: ['./tsconfig.json', './tsconfig.node.json'], + tsconfigRootDir: __dirname, + }, +} +``` + +- Replace `plugin:@typescript-eslint/recommended` to `plugin:@typescript-eslint/recommended-type-checked` or `plugin:@typescript-eslint/strict-type-checked` +- Optionally add `plugin:@typescript-eslint/stylistic-type-checked` +- Install [eslint-plugin-react](https://github.com/jsx-eslint/eslint-plugin-react) and add `plugin:react/recommended` & `plugin:react/jsx-runtime` to the `extends` list diff --git a/custom-demo/back-end/web_app/components.json b/custom-demo/back-end/web_app/components.json new file mode 100644 index 0000000..0cc3425 --- /dev/null +++ b/custom-demo/back-end/web_app/components.json @@ -0,0 +1,16 @@ +{ + "$schema": "https://ui.shadcn.com/schema.json", + "style": "new-york", + "rsc": false, + "tsx": true, + "tailwind": { + "config": "tailwind.config.js", + "css": "app/globals.css", + "baseColor": "gray", + "cssVariables": true + }, + "aliases": { + "components": "@/components", + "utils": "@/lib/utils" + } +} \ No newline at end of file diff --git a/custom-demo/back-end/web_app/index.html b/custom-demo/back-end/web_app/index.html new file mode 100644 index 0000000..01ffd1f --- /dev/null +++ b/custom-demo/back-end/web_app/index.html @@ -0,0 +1,12 @@ + + + + + + IOPaint + + +
+ + + diff --git a/custom-demo/back-end/web_app/package-lock.json b/custom-demo/back-end/web_app/package-lock.json new file mode 100644 index 0000000..283e3bc --- /dev/null +++ b/custom-demo/back-end/web_app/package-lock.json @@ -0,0 +1,6634 @@ +{ + "name": "web_app", + "version": "0.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "web_app", + "version": "0.0.0", + "dependencies": { + "@heroicons/react": "^2.0.18", + "@hookform/resolvers": "^3.3.2", + "@radix-ui/react-accordion": "^1.1.2", + "@radix-ui/react-alert-dialog": "^1.0.5", + "@radix-ui/react-context-menu": "^2.1.5", + "@radix-ui/react-dialog": "^1.0.5", + "@radix-ui/react-dropdown-menu": "^2.0.6", + "@radix-ui/react-icons": "^1.3.0", + "@radix-ui/react-label": "^2.0.2", + "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-progress": "^1.0.3", + "@radix-ui/react-radio-group": "^1.1.3", + "@radix-ui/react-scroll-area": "^1.0.5", + "@radix-ui/react-select": "^2.0.0", + "@radix-ui/react-separator": "^1.0.3", + "@radix-ui/react-slider": "^1.1.2", + "@radix-ui/react-slot": "^1.0.2", + "@radix-ui/react-switch": "^1.0.3", + "@radix-ui/react-tabs": "^1.0.4", + "@radix-ui/react-toast": "^1.1.5", + "@radix-ui/react-toggle": "^1.0.3", + "@radix-ui/react-tooltip": "^1.0.7", + "@tanstack/react-query": "^5.8.7", + "@uidotdev/usehooks": "^2.4.1", + "axios": "^1.11.0", + "class-variance-authority": "^0.7.0", + "clsx": "^2.0.0", + "fuse.js": "^7.0.0", + "immer": "^10.0.3", + "inter-ui": "^4.0.0", + "lodash": "^4.17.21", + "lucide-react": "^0.292.0", + "mitt": "^3.0.1", + "next-themes": "^0.2.1", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-hook-form": "^7.48.2", + "react-hotkeys-hook": "^4.4.1", + "react-photo-album": "^2.3.0", + "react-use": "^17.4.0", + "react-zoom-pan-pinch": "^3.3.0", + "recoil": "^0.7.7", + "socket.io-client": "^4.7.2", + "tailwind-merge": "^2.0.0", + "tailwindcss-animate": "^1.0.7", + "zod": "^3.22.4", + "zundo": "^2.0.0", + "zustand": "^4.4.6" + }, + "devDependencies": { + "@tanstack/eslint-plugin-query": "^5.8.4", + "@types/axios": "^0.14.4", + "@types/flexsearch": "^0.7.6", + "@types/lodash": "^4.14.201", + "@types/node": "^20.9.2", + "@types/react": "^18.2.37", + "@types/react-dom": "^18.2.15", + "@typescript-eslint/eslint-plugin": "^6.10.0", + "@typescript-eslint/parser": "^6.10.0", + "@vitejs/plugin-react": "^4.2.0", + "@vitejs/plugin-react-swc": "^3.5.0", + "autoprefixer": "^10.4.16", + "eslint": "^8.53.0", + "eslint-plugin-react-hooks": "^4.6.0", + "eslint-plugin-react-refresh": "^0.4.4", + "postcss": "^8.4.31", + "tailwindcss": "^3.3.5", + "typescript": "^5.2.2", + "vite": "^5.0.0" + } + }, + "node_modules/@aashutoshrathi/word-wrap": { + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/@aashutoshrathi/word-wrap/-/word-wrap-1.2.6.tgz", + "integrity": "sha512-1Yjs2SvM8TflER/OD3cOjhWWOZb58A2t7wpE2S9XfBYTiIl+XFhQG2bjy4Pu1I+EAlCNUzRDYDdFwFYUKvXcIA==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/@alloc/quick-lru": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", + "integrity": "sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/@ampproject/remapping": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.2.1.tgz", + "integrity": "sha512-lFMjJTrFL3j7L9yBxwYfCq2k6qqwHyzuUl/XBnif78PWTJYyL/dfowQHWE3sp6U6ZzqWiiIZnpTMO96zhkjwtg==", + "dev": true, + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.0", + "@jridgewell/trace-mapping": "^0.3.9" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/code-frame": { + "version": "7.22.13", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.22.13.tgz", + "integrity": "sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w==", + "dev": true, + "dependencies": { + "@babel/highlight": "^7.22.13", + "chalk": "^2.4.2" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/code-frame/node_modules/ansi-styles": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", + "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", + "dev": true, + "dependencies": { + "color-convert": "^1.9.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/code-frame/node_modules/chalk": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz", + "integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==", + "dev": true, + "dependencies": { + "ansi-styles": "^3.2.1", + "escape-string-regexp": "^1.0.5", + "supports-color": "^5.3.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/code-frame/node_modules/color-convert": { + "version": "1.9.3", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz", + "integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==", + "dev": true, + "dependencies": { + "color-name": "1.1.3" + } + }, + "node_modules/@babel/code-frame/node_modules/color-name": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz", + "integrity": "sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==", + "dev": true + }, + "node_modules/@babel/code-frame/node_modules/escape-string-regexp": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", + "integrity": "sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==", + "dev": true, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/@babel/code-frame/node_modules/has-flag": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", + "integrity": "sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/code-frame/node_modules/supports-color": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz", + "integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==", + "dev": true, + "dependencies": { + "has-flag": "^3.0.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/compat-data": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.23.3.tgz", + "integrity": "sha512-BmR4bWbDIoFJmJ9z2cZ8Gmm2MXgEDgjdWgpKmKWUt54UGFJdlj31ECtbaDvCG/qVdG3AQ1SfpZEs01lUFbzLOQ==", + "dev": true, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/core": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/core/-/core-7.23.3.tgz", + "integrity": "sha512-Jg+msLuNuCJDyBvFv5+OKOUjWMZgd85bKjbICd3zWrKAo+bJ49HJufi7CQE0q0uR8NGyO6xkCACScNqyjHSZew==", + "dev": true, + "dependencies": { + "@ampproject/remapping": "^2.2.0", + "@babel/code-frame": "^7.22.13", + "@babel/generator": "^7.23.3", + "@babel/helper-compilation-targets": "^7.22.15", + "@babel/helper-module-transforms": "^7.23.3", + "@babel/helpers": "^7.23.2", + "@babel/parser": "^7.23.3", + "@babel/template": "^7.22.15", + "@babel/traverse": "^7.23.3", + "@babel/types": "^7.23.3", + "convert-source-map": "^2.0.0", + "debug": "^4.1.0", + "gensync": "^1.0.0-beta.2", + "json5": "^2.2.3", + "semver": "^6.3.1" + }, + "engines": { + "node": ">=6.9.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/babel" + } + }, + "node_modules/@babel/core/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "bin": { + "semver": "bin/semver.js" + } + }, + "node_modules/@babel/generator": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.23.3.tgz", + "integrity": "sha512-keeZWAV4LU3tW0qRi19HRpabC/ilM0HRBBzf9/k8FFiG4KVpiv0FIy4hHfLfFQZNhziCTPTmd59zoyv6DNISzg==", + "dev": true, + "dependencies": { + "@babel/types": "^7.23.3", + "@jridgewell/gen-mapping": "^0.3.2", + "@jridgewell/trace-mapping": "^0.3.17", + "jsesc": "^2.5.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-compilation-targets": { + "version": "7.22.15", + "resolved": "https://registry.npmjs.org/@babel/helper-compilation-targets/-/helper-compilation-targets-7.22.15.tgz", + "integrity": "sha512-y6EEzULok0Qvz8yyLkCvVX+02ic+By2UdOhylwUOvOn9dvYc9mKICJuuU1n1XBI02YWsNsnrY1kc6DVbjcXbtw==", + "dev": true, + "dependencies": { + "@babel/compat-data": "^7.22.9", + "@babel/helper-validator-option": "^7.22.15", + "browserslist": "^4.21.9", + "lru-cache": "^5.1.1", + "semver": "^6.3.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-compilation-targets/node_modules/lru-cache": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-5.1.1.tgz", + "integrity": "sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w==", + "dev": true, + "dependencies": { + "yallist": "^3.0.2" + } + }, + "node_modules/@babel/helper-compilation-targets/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "bin": { + "semver": "bin/semver.js" + } + }, + "node_modules/@babel/helper-compilation-targets/node_modules/yallist": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", + "integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==", + "dev": true + }, + "node_modules/@babel/helper-environment-visitor": { + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz", + "integrity": "sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA==", + "dev": true, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-function-name": { + "version": "7.23.0", + "resolved": "https://registry.npmjs.org/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz", + "integrity": "sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw==", + "dev": true, + "dependencies": { + "@babel/template": "^7.22.15", + "@babel/types": "^7.23.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-hoist-variables": { + "version": "7.22.5", + "resolved": "https://registry.npmjs.org/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz", + "integrity": "sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw==", + "dev": true, + "dependencies": { + "@babel/types": "^7.22.5" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-module-imports": { + "version": "7.22.15", + "resolved": "https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.22.15.tgz", + "integrity": "sha512-0pYVBnDKZO2fnSPCrgM/6WMc7eS20Fbok+0r88fp+YtWVLZrp4CkafFGIp+W0VKw4a22sgebPT99y+FDNMdP4w==", + "dev": true, + "dependencies": { + "@babel/types": "^7.22.15" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-module-transforms": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/helper-module-transforms/-/helper-module-transforms-7.23.3.tgz", + "integrity": "sha512-7bBs4ED9OmswdfDzpz4MpWgSrV7FXlc3zIagvLFjS5H+Mk7Snr21vQ6QwrsoCGMfNC4e4LQPdoULEt4ykz0SRQ==", + "dev": true, + "dependencies": { + "@babel/helper-environment-visitor": "^7.22.20", + "@babel/helper-module-imports": "^7.22.15", + "@babel/helper-simple-access": "^7.22.5", + "@babel/helper-split-export-declaration": "^7.22.6", + "@babel/helper-validator-identifier": "^7.22.20" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0" + } + }, + "node_modules/@babel/helper-plugin-utils": { + "version": "7.22.5", + "resolved": "https://registry.npmjs.org/@babel/helper-plugin-utils/-/helper-plugin-utils-7.22.5.tgz", + "integrity": "sha512-uLls06UVKgFG9QD4OeFYLEGteMIAa5kpTPcFL28yuCIIzsf6ZyKZMllKVOCZFhiZ5ptnwX4mtKdWCBE/uT4amg==", + "dev": true, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-simple-access": { + "version": "7.22.5", + "resolved": "https://registry.npmjs.org/@babel/helper-simple-access/-/helper-simple-access-7.22.5.tgz", + "integrity": "sha512-n0H99E/K+Bika3++WNL17POvo4rKWZ7lZEp1Q+fStVbUi8nxPQEBOlTmCOxW/0JsS56SKKQ+ojAe2pHKJHN35w==", + "dev": true, + "dependencies": { + "@babel/types": "^7.22.5" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-split-export-declaration": { + "version": "7.22.6", + "resolved": "https://registry.npmjs.org/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz", + "integrity": "sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g==", + "dev": true, + "dependencies": { + "@babel/types": "^7.22.5" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-string-parser": { + "version": "7.22.5", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz", + "integrity": "sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw==", + "dev": true, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz", + "integrity": "sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==", + "dev": true, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-option": { + "version": "7.22.15", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-option/-/helper-validator-option-7.22.15.tgz", + "integrity": "sha512-bMn7RmyFjY/mdECUbgn9eoSY4vqvacUnS9i9vGAGttgFWesO6B4CYWA7XlpbWgBt71iv/hfbPlynohStqnu5hA==", + "dev": true, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helpers": { + "version": "7.23.2", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.23.2.tgz", + "integrity": "sha512-lzchcp8SjTSVe/fPmLwtWVBFC7+Tbn8LGHDVfDp9JGxpAY5opSaEFgt8UQvrnECWOTdji2mOWMz1rOhkHscmGQ==", + "dev": true, + "dependencies": { + "@babel/template": "^7.22.15", + "@babel/traverse": "^7.23.2", + "@babel/types": "^7.23.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/highlight": { + "version": "7.22.20", + "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.22.20.tgz", + "integrity": "sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg==", + "dev": true, + "dependencies": { + "@babel/helper-validator-identifier": "^7.22.20", + "chalk": "^2.4.2", + "js-tokens": "^4.0.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/highlight/node_modules/ansi-styles": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", + "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", + "dev": true, + "dependencies": { + "color-convert": "^1.9.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/highlight/node_modules/chalk": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz", + "integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==", + "dev": true, + "dependencies": { + "ansi-styles": "^3.2.1", + "escape-string-regexp": "^1.0.5", + "supports-color": "^5.3.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/highlight/node_modules/color-convert": { + "version": "1.9.3", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz", + "integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==", + "dev": true, + "dependencies": { + "color-name": "1.1.3" + } + }, + "node_modules/@babel/highlight/node_modules/color-name": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz", + "integrity": "sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==", + "dev": true + }, + "node_modules/@babel/highlight/node_modules/escape-string-regexp": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", + "integrity": "sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==", + "dev": true, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/@babel/highlight/node_modules/has-flag": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", + "integrity": "sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/highlight/node_modules/supports-color": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz", + "integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==", + "dev": true, + "dependencies": { + "has-flag": "^3.0.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/parser": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.23.3.tgz", + "integrity": "sha512-uVsWNvlVsIninV2prNz/3lHCb+5CJ+e+IUBfbjToAHODtfGYLfCFuY4AU7TskI+dAKk+njsPiBjq1gKTvZOBaw==", + "dev": true, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/plugin-transform-react-jsx-self": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx-self/-/plugin-transform-react-jsx-self-7.23.3.tgz", + "integrity": "sha512-qXRvbeKDSfwnlJnanVRp0SfuWE5DQhwQr5xtLBzp56Wabyo+4CMosF6Kfp+eOD/4FYpql64XVJ2W0pVLlJZxOQ==", + "dev": true, + "dependencies": { + "@babel/helper-plugin-utils": "^7.22.5" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0-0" + } + }, + "node_modules/@babel/plugin-transform-react-jsx-source": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx-source/-/plugin-transform-react-jsx-source-7.23.3.tgz", + "integrity": "sha512-91RS0MDnAWDNvGC6Wio5XYkyWI39FMFO+JK9+4AlgaTH+yWwVTsw7/sn6LK0lH7c5F+TFkpv/3LfCJ1Ydwof/g==", + "dev": true, + "dependencies": { + "@babel/helper-plugin-utils": "^7.22.5" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0-0" + } + }, + "node_modules/@babel/runtime": { + "version": "7.23.2", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.23.2.tgz", + "integrity": "sha512-mM8eg4yl5D6i3lu2QKPuPH4FArvJ8KhTofbE7jwMUv9KX5mBvwPAqnV3MlyBNqdp9RyRKP6Yck8TrfYrPvX3bg==", + "dependencies": { + "regenerator-runtime": "^0.14.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/template": { + "version": "7.22.15", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.22.15.tgz", + "integrity": "sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==", + "dev": true, + "dependencies": { + "@babel/code-frame": "^7.22.13", + "@babel/parser": "^7.22.15", + "@babel/types": "^7.22.15" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/traverse": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.3.tgz", + "integrity": "sha512-+K0yF1/9yR0oHdE0StHuEj3uTPzwwbrLGfNOndVJVV2TqA5+j3oljJUb4nmB954FLGjNem976+B+eDuLIjesiQ==", + "dev": true, + "dependencies": { + "@babel/code-frame": "^7.22.13", + "@babel/generator": "^7.23.3", + "@babel/helper-environment-visitor": "^7.22.20", + "@babel/helper-function-name": "^7.23.0", + "@babel/helper-hoist-variables": "^7.22.5", + "@babel/helper-split-export-declaration": "^7.22.6", + "@babel/parser": "^7.23.3", + "@babel/types": "^7.23.3", + "debug": "^4.1.0", + "globals": "^11.1.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/traverse/node_modules/globals": { + "version": "11.12.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-11.12.0.tgz", + "integrity": "sha512-WOBp/EEGUiIsJSp7wcv/y6MO+lV9UoncWqxuFfm8eBwzWNgyfBd6Gz+IeKQ9jCmyhoH99g15M3T+QaVHFjizVA==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/@babel/types": { + "version": "7.23.3", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.23.3.tgz", + "integrity": "sha512-OZnvoH2l8PK5eUvEcUyCt/sXgr/h+UWpVuBbOljwcrAgUl6lpchoQ++PHGyQy1AtYnVA6CEq3y5xeEI10brpXw==", + "dev": true, + "dependencies": { + "@babel/helper-string-parser": "^7.22.5", + "@babel/helper-validator-identifier": "^7.22.20", + "to-fast-properties": "^2.0.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.19.6.tgz", + "integrity": "sha512-muPzBqXJKCbMYoNbb1JpZh/ynl0xS6/+pLjrofcR3Nad82SbsCogYzUE6Aq9QT3cLP0jR/IVK/NHC9b90mSHtg==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.19.6.tgz", + "integrity": "sha512-KQ/hbe9SJvIJ4sR+2PcZ41IBV+LPJyYp6V1K1P1xcMRup9iYsBoQn4MzE3mhMLOld27Au2eDcLlIREeKGUXpHQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.19.6.tgz", + "integrity": "sha512-VVJVZQ7p5BBOKoNxd0Ly3xUM78Y4DyOoFKdkdAe2m11jbh0LEU4bPles4e/72EMl4tapko8o915UalN/5zhspg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.19.6.tgz", + "integrity": "sha512-91LoRp/uZAKx6ESNspL3I46ypwzdqyDLXZH7x2QYCLgtnaU08+AXEbabY2yExIz03/am0DivsTtbdxzGejfXpA==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.19.6.tgz", + "integrity": "sha512-QCGHw770ubjBU1J3ZkFJh671MFajGTYMZumPs9E/rqU52md6lIil97BR0CbPq6U+vTh3xnTNDHKRdR8ggHnmxQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.19.6.tgz", + "integrity": "sha512-J53d0jGsDcLzWk9d9SPmlyF+wzVxjXpOH7jVW5ae7PvrDst4kiAz6sX+E8btz0GB6oH12zC+aHRD945jdjF2Vg==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.19.6.tgz", + "integrity": "sha512-hn9qvkjHSIB5Z9JgCCjED6YYVGCNpqB7dEGavBdG6EjBD8S/UcNUIlGcB35NCkMETkdYwfZSvD9VoDJX6VeUVA==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.19.6.tgz", + "integrity": "sha512-G8IR5zFgpXad/Zp7gr7ZyTKyqZuThU6z1JjmRyN1vSF8j0bOlGzUwFSMTbctLAdd7QHpeyu0cRiuKrqK1ZTwvQ==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.19.6.tgz", + "integrity": "sha512-HQCOrk9XlH3KngASLaBfHpcoYEGUt829A9MyxaI8RMkfRA8SakG6YQEITAuwmtzFdEu5GU4eyhKcpv27dFaOBg==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.19.6.tgz", + "integrity": "sha512-22eOR08zL/OXkmEhxOfshfOGo8P69k8oKHkwkDrUlcB12S/sw/+COM4PhAPT0cAYW/gpqY2uXp3TpjQVJitz7w==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.19.6.tgz", + "integrity": "sha512-82RvaYAh/SUJyjWA8jDpyZCHQjmEggL//sC7F3VKYcBMumQjUL3C5WDl/tJpEiKtt7XrWmgjaLkrk205zfvwTA==", + "cpu": [ + "loong64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.19.6.tgz", + "integrity": "sha512-8tvnwyYJpR618vboIv2l8tK2SuK/RqUIGMfMENkeDGo3hsEIrpGldMGYFcWxWeEILe5Fi72zoXLmhZ7PR23oQA==", + "cpu": [ + "mips64el" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.19.6.tgz", + "integrity": "sha512-Qt+D7xiPajxVNk5tQiEJwhmarNnLPdjXAoA5uWMpbfStZB0+YU6a3CtbWYSy+sgAsnyx4IGZjWsTzBzrvg/fMA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.19.6.tgz", + "integrity": "sha512-lxRdk0iJ9CWYDH1Wpnnnc640ajF4RmQ+w6oHFZmAIYu577meE9Ka/DCtpOrwr9McMY11ocbp4jirgGgCi7Ls/g==", + "cpu": [ + "riscv64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.19.6.tgz", + "integrity": "sha512-MopyYV39vnfuykHanRWHGRcRC3AwU7b0QY4TI8ISLfAGfK+tMkXyFuyT1epw/lM0pflQlS53JoD22yN83DHZgA==", + "cpu": [ + "s390x" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.19.6.tgz", + "integrity": "sha512-UWcieaBzsN8WYbzFF5Jq7QULETPcQvlX7KL4xWGIB54OknXJjBO37sPqk7N82WU13JGWvmDzFBi1weVBajPovg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.19.6.tgz", + "integrity": "sha512-EpWiLX0fzvZn1wxtLxZrEW+oQED9Pwpnh+w4Ffv8ZLuMhUoqR9q9rL4+qHW8F4Mg5oQEKxAoT0G+8JYNqCiR6g==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.19.6.tgz", + "integrity": "sha512-fFqTVEktM1PGs2sLKH4M5mhAVEzGpeZJuasAMRnvDZNCV0Cjvm1Hu35moL2vC0DOrAQjNTvj4zWrol/lwQ8Deg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.19.6.tgz", + "integrity": "sha512-M+XIAnBpaNvaVAhbe3uBXtgWyWynSdlww/JNZws0FlMPSBy+EpatPXNIlKAdtbFVII9OpX91ZfMb17TU3JKTBA==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.19.6.tgz", + "integrity": "sha512-2DchFXn7vp/B6Tc2eKdTsLzE0ygqKkNUhUBCNtMx2Llk4POIVMUq5rUYjdcedFlGLeRe1uLCpVvCmE+G8XYybA==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.19.6.tgz", + "integrity": "sha512-PBo/HPDQllyWdjwAVX+Gl2hH0dfBydL97BAH/grHKC8fubqp02aL4S63otZ25q3sBdINtOBbz1qTZQfXbP4VBg==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.19.6.tgz", + "integrity": "sha512-OE7yIdbDif2kKfrGa+V0vx/B3FJv2L4KnIiLlvtibPyO9UkgO3rzYE0HhpREo2vmJ1Ixq1zwm9/0er+3VOSZJA==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@eslint-community/eslint-utils": { + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz", + "integrity": "sha512-1/sA4dwrzBAyeUoQ6oxahHKmrZvsnLCg4RfxW3ZFGGmQkSNQPFNLV9CUEFQP1x9EYXHTo5p6xdhZM1Ne9p/AfA==", + "dev": true, + "dependencies": { + "eslint-visitor-keys": "^3.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" + } + }, + "node_modules/@eslint-community/regexpp": { + "version": "4.10.0", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.10.0.tgz", + "integrity": "sha512-Cu96Sd2By9mCNTx2iyKOmq10v22jUVQv0lQnlGNy16oE9589yE+QADPbrMGCkA51cKZSg3Pu/aTJVTGfL/qjUA==", + "dev": true, + "engines": { + "node": "^12.0.0 || ^14.0.0 || >=16.0.0" + } + }, + "node_modules/@eslint/eslintrc": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.1.3.tgz", + "integrity": "sha512-yZzuIG+jnVu6hNSzFEN07e8BxF3uAzYtQb6uDkaYZLo6oYZDCq454c5kB8zxnzfCYyP4MIuyBn10L0DqwujTmA==", + "dev": true, + "dependencies": { + "ajv": "^6.12.4", + "debug": "^4.3.2", + "espree": "^9.6.0", + "globals": "^13.19.0", + "ignore": "^5.2.0", + "import-fresh": "^3.2.1", + "js-yaml": "^4.1.0", + "minimatch": "^3.1.2", + "strip-json-comments": "^3.1.1" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@eslint/js": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.54.0.tgz", + "integrity": "sha512-ut5V+D+fOoWPgGGNj83GGjnntO39xDy6DWxO0wb7Jp3DcMX0TfIqdzHF85VTQkerdyGmuuMD9AKAo5KiNlf/AQ==", + "dev": true, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, + "node_modules/@floating-ui/core": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.5.0.tgz", + "integrity": "sha512-kK1h4m36DQ0UHGj5Ah4db7R0rHemTqqO0QLvUqi1/mUUp3LuAWbWxdxSIf/XsnH9VS6rRVPLJCncjRzUvyCLXg==", + "dependencies": { + "@floating-ui/utils": "^0.1.3" + } + }, + "node_modules/@floating-ui/dom": { + "version": "1.5.3", + "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.5.3.tgz", + "integrity": "sha512-ClAbQnEqJAKCJOEbbLo5IUlZHkNszqhuxS4fHAVxRPXPya6Ysf2G8KypnYcOTpx6I8xcgF9bbHb6g/2KpbV8qA==", + "dependencies": { + "@floating-ui/core": "^1.4.2", + "@floating-ui/utils": "^0.1.3" + } + }, + "node_modules/@floating-ui/react-dom": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.0.4.tgz", + "integrity": "sha512-CF8k2rgKeh/49UrnIBs4BdxPUV6vize/Db1d/YbCLyp9GiVZ0BEwf5AiDSxJRCr6yOkGqTFHtmrULxkEfYZ7dQ==", + "dependencies": { + "@floating-ui/dom": "^1.5.1" + }, + "peerDependencies": { + "react": ">=16.8.0", + "react-dom": ">=16.8.0" + } + }, + "node_modules/@floating-ui/utils": { + "version": "0.1.6", + "resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.1.6.tgz", + "integrity": "sha512-OfX7E2oUDYxtBvsuS4e/jSn4Q9Qb6DzgeYtsAdkPZ47znpoNsMgZw0+tVijiv3uGNR6dgNlty6r9rzIzHjtd/A==" + }, + "node_modules/@heroicons/react": { + "version": "2.0.18", + "resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.0.18.tgz", + "integrity": "sha512-7TyMjRrZZMBPa+/5Y8lN0iyvUU/01PeMGX2+RE7cQWpEUIcb4QotzUObFkJDejj/HUH4qjP/eQ0gzzKs2f+6Yw==", + "peerDependencies": { + "react": ">= 16" + } + }, + "node_modules/@hookform/resolvers": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/@hookform/resolvers/-/resolvers-3.3.2.tgz", + "integrity": "sha512-Tw+GGPnBp+5DOsSg4ek3LCPgkBOuOgS5DsDV7qsWNH9LZc433kgsWICjlsh2J9p04H2K66hsXPPb9qn9ILdUtA==", + "peerDependencies": { + "react-hook-form": "^7.0.0" + } + }, + "node_modules/@humanwhocodes/config-array": { + "version": "0.11.13", + "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz", + "integrity": "sha512-JSBDMiDKSzQVngfRjOdFXgFfklaXI4K9nLF49Auh21lmBWRLIK3+xTErTWD4KU54pb6coM6ESE7Awz/FNU3zgQ==", + "dev": true, + "dependencies": { + "@humanwhocodes/object-schema": "^2.0.1", + "debug": "^4.1.1", + "minimatch": "^3.0.5" + }, + "engines": { + "node": ">=10.10.0" + } + }, + "node_modules/@humanwhocodes/module-importer": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz", + "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==", + "dev": true, + "engines": { + "node": ">=12.22" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/object-schema": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/object-schema/-/object-schema-2.0.1.tgz", + "integrity": "sha512-dvuCeX5fC9dXgJn9t+X5atfmgQAzUOWqS1254Gh0m6i8wKd10ebXkfNKiRK+1GWi/yTvvLDHpoxLr0xxxeslWw==", + "dev": true + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.3", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz", + "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==", + "dependencies": { + "@jridgewell/set-array": "^1.0.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.9" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz", + "integrity": "sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/set-array": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", + "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.4.15", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz", + "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.20", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.20.tgz", + "integrity": "sha512-R8LcPeWZol2zR8mmH3JeKQ6QRCFb7XgUhV9ZlGhHLGyg4wpPiPZNQOOWhFZhxKw8u//yTbNGI42Bx/3paXEQ+Q==", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@next/env": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/env/-/env-14.0.3.tgz", + "integrity": "sha512-7xRqh9nMvP5xrW4/+L0jgRRX+HoNRGnfJpD+5Wq6/13j3dsdzxO3BCXn7D3hMqsDb+vjZnJq+vI7+EtgrYZTeA==", + "peer": true + }, + "node_modules/@next/swc-darwin-arm64": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.0.3.tgz", + "integrity": "sha512-64JbSvi3nbbcEtyitNn2LEDS/hcleAFpHdykpcnrstITFlzFgB/bW0ER5/SJJwUPj+ZPY+z3e+1jAfcczRLVGw==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "darwin" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-darwin-x64": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.0.3.tgz", + "integrity": "sha512-RkTf+KbAD0SgYdVn1XzqE/+sIxYGB7NLMZRn9I4Z24afrhUpVJx6L8hsRnIwxz3ERE2NFURNliPjJ2QNfnWicQ==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "darwin" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-arm64-gnu": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.0.3.tgz", + "integrity": "sha512-3tBWGgz7M9RKLO6sPWC6c4pAw4geujSwQ7q7Si4d6bo0l6cLs4tmO+lnSwFp1Tm3lxwfMk0SgkJT7EdwYSJvcg==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-arm64-musl": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.0.3.tgz", + "integrity": "sha512-v0v8Kb8j8T23jvVUWZeA2D8+izWspeyeDGNaT2/mTHWp7+37fiNfL8bmBWiOmeumXkacM/AB0XOUQvEbncSnHA==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-gnu": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.0.3.tgz", + "integrity": "sha512-VM1aE1tJKLBwMGtyBR21yy+STfl0MapMQnNrXkxeyLs0GFv/kZqXS5Jw/TQ3TSUnbv0QPDf/X8sDXuMtSgG6eg==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-musl": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.0.3.tgz", + "integrity": "sha512-64EnmKy18MYFL5CzLaSuUn561hbO1Gk16jM/KHznYP3iCIfF9e3yULtHaMy0D8zbHfxset9LTOv6cuYKJgcOxg==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-arm64-msvc": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.0.3.tgz", + "integrity": "sha512-WRDp8QrmsL1bbGtsh5GqQ/KWulmrnMBgbnb+59qNTW1kVi1nG/2ndZLkcbs2GX7NpFLlToLRMWSQXmPzQm4tog==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "win32" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-ia32-msvc": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.0.3.tgz", + "integrity": "sha512-EKffQeqCrj+t6qFFhIFTRoqb2QwX1mU7iTOvMyLbYw3QtqTw9sMwjykyiMlZlrfm2a4fA84+/aeW+PMg1MjuTg==", + "cpu": [ + "ia32" + ], + "optional": true, + "os": [ + "win32" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-x64-msvc": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.0.3.tgz", + "integrity": "sha512-ERhKPSJ1vQrPiwrs15Pjz/rvDHZmkmvbf/BjPN/UCOI++ODftT0GtasDPi0j+y6PPJi5HsXw+dpRaXUaw4vjuQ==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "win32" + ], + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@radix-ui/number": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/number/-/number-1.0.1.tgz", + "integrity": "sha512-T5gIdVO2mmPW3NNhjNgEP3cqMXjXL9UbO0BzWcXfvdBs+BohbQxvd/K5hSVKmn9/lbTdsQVKbUcP5WLCwvUbBg==", + "dependencies": { + "@babel/runtime": "^7.13.10" + } + }, + "node_modules/@radix-ui/primitive": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.0.1.tgz", + "integrity": "sha512-yQ8oGX2GVsEYMWGxcovu1uGWPCxV5BFfeeYxqPmuAzUyLT9qmaMXSAhXpb0WrspIeqYzdJpkh2vHModJPgRIaw==", + "dependencies": { + "@babel/runtime": "^7.13.10" + } + }, + "node_modules/@radix-ui/react-accordion": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-accordion/-/react-accordion-1.1.2.tgz", + "integrity": "sha512-fDG7jcoNKVjSK6yfmuAs0EnPDro0WMXIhMtXdTBWqEioVW206ku+4Lw07e+13lUkFkpoEQ2PdeMIAGpdqEAmDg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-collapsible": "1.0.3", + "@radix-ui/react-collection": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-controllable-state": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-alert-dialog": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-alert-dialog/-/react-alert-dialog-1.0.5.tgz", + "integrity": "sha512-OrVIOcZL0tl6xibeuGt5/+UxoT2N27KCFOPjFyfXMnchxSHZ/OW7cCX2nGlIYJrbHK/fczPcFzAwvNBB6XBNMA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-dialog": "1.0.5", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-arrow": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.0.3.tgz", + "integrity": "sha512-wSP+pHsB/jQRaL6voubsQ/ZlrGBHHrOjmBnr19hxYgtS0WvAFwZhK2WP/YY5yF9uKECCEEDGxuLxq1NBK51wFA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-collapsible": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.0.3.tgz", + "integrity": "sha512-UBmVDkmR6IvDsloHVN+3rtx4Mi5TFvylYXpluuv0f37dtaz3H99bp8No0LGXRigVpl3UAT4l9j6bIchh42S/Gg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-collection": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.0.3.tgz", + "integrity": "sha512-3SzW+0PW7yBBoQlT8wNcGtaxaD0XSu0uLUFgrtHY08Acx05TaHaOmVLR73c0j/cqpDy53KBMO7s0dx2wmOIDIA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-compose-refs": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.0.1.tgz", + "integrity": "sha512-fDSBgd44FKHa1FRMU59qBMPFcl2PZE+2nmqunj+BWFyYYjnhIDWL2ItDs3rrbJDQOtzt5nIebLCQc4QRfz6LJw==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-context": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.0.1.tgz", + "integrity": "sha512-ebbrdFoYTcuZ0v4wG5tedGnp9tzcV8awzsxYph7gXUyvnNLuTIcCk1q17JEbnVhXAKG9oX3KtchwiMIAYp9NLg==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-context-menu": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.1.5.tgz", + "integrity": "sha512-R5XaDj06Xul1KGb+WP8qiOh7tKJNz2durpLBXAGZjSVtctcRFCuEvy2gtMwRJGePwQQE5nV77gs4FwRi8T+r2g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-menu": "2.0.6", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-controllable-state": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz", + "integrity": "sha512-GjWJX/AUpB703eEBanuBnIWdIXg6NvJFCXcNlSZk4xdszCdhrJgBoUd1cGk67vFO+WdA2pfI/plOpqz/5GUP6Q==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-focus-guards": "1.0.1", + "@radix-ui/react-focus-scope": "1.0.4", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2", + "@radix-ui/react-use-controllable-state": "1.0.1", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.5.5" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-direction": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-direction/-/react-direction-1.0.1.tgz", + "integrity": "sha512-RXcvnXgyvYvBEOhCBuddKecVkoMiI10Jcm5cTI7abJRAHYfFxeu+FBQs/DvdxSYucxR5mna0dNsL6QFlds5TMA==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.5.tgz", + "integrity": "sha512-aJeDjQhywg9LBu2t/At58hCvr7pEm0o2Ke1x33B+MhjNmmZ17sy4KImo0KPLgsnc/zN7GPdce8Cnn0SWvwZO7g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-escape-keydown": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dropdown-menu": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dropdown-menu/-/react-dropdown-menu-2.0.6.tgz", + "integrity": "sha512-i6TuFOoWmLWq+M/eCLGd/bQ2HfAX1RJgvrBQ6AQLmzfvsLdefxbWu8G9zczcPFfcSPehz9GcpF6K9QYreFV8hA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-menu": "2.0.6", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-controllable-state": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-focus-guards": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.0.1.tgz", + "integrity": "sha512-Rect2dWbQ8waGzhMavsIbmSVCgYxkXLxxR3ZvCX79JOglzdEy4JXMb98lq4hPxUbLr77nP0UOGf4rcMU+s1pUA==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-focus-scope": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.0.4.tgz", + "integrity": "sha512-sL04Mgvf+FmyvZeYfNu1EPAaaxD+aw7cYeIB9L9Fvq8+urhltTRaEo5ysKOpHuKPclsZcSUMKlN05x4u+CINpA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-icons": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/@radix-ui/react-icons/-/react-icons-1.3.0.tgz", + "integrity": "sha512-jQxj/0LKgp+j9BiTXz3O3sgs26RNet2iLWmsPyRz2SIcR4q/4SbazXfnYwbAr+vLYKSfc7qxzyGQA1HLlYiuNw==", + "peerDependencies": { + "react": "^16.x || ^17.x || ^18.x" + } + }, + "node_modules/@radix-ui/react-id": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-id/-/react-id-1.0.1.tgz", + "integrity": "sha512-tI7sT/kqYp8p96yGWY1OAnLHrqDgzHefRBKQ2YAkBS5ja7QLcZ9Z/uY7bEjPUatf8RomoXM8/1sMj1IJaE5UzQ==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-label": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.0.2.tgz", + "integrity": "sha512-N5ehvlM7qoTLx7nWPodsPYPgMzA5WM8zZChQg8nyFJKnDO5WHdba1vv5/H6IO5LtJMfD2Q3wh1qHFGNtK0w3bQ==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-menu": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.0.6.tgz", + "integrity": "sha512-BVkFLS+bUC8HcImkRKPSiVumA1VPOOEC5WBMiT+QAVsPzW1FJzI9KnqgGxVDPBcql5xXrHkD3JOVoXWEXD8SYA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-collection": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-focus-guards": "1.0.1", + "@radix-ui/react-focus-scope": "1.0.4", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-popper": "1.1.3", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-roving-focus": "1.0.4", + "@radix-ui/react-slot": "1.0.2", + "@radix-ui/react-use-callback-ref": "1.0.1", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.5.5" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-popover": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popover/-/react-popover-1.0.7.tgz", + "integrity": "sha512-shtvVnlsxT6faMnK/a7n0wptwBD23xc1Z5mdrtKLwVEfsEMXodS0r5s0/g5P0hX//EKYZS2sxUjqfzlg52ZSnQ==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-focus-guards": "1.0.1", + "@radix-ui/react-focus-scope": "1.0.4", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-popper": "1.1.3", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2", + "@radix-ui/react-use-controllable-state": "1.0.1", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.5.5" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-popper": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.1.3.tgz", + "integrity": "sha512-cKpopj/5RHZWjrbF2846jBNacjQVwkP068DfmgrNJXpvVWrOvlAmE9xSiy5OqeE+Gi8D9fP+oDhUnPqNMY8/5w==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@floating-ui/react-dom": "^2.0.0", + "@radix-ui/react-arrow": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1", + "@radix-ui/react-use-rect": "1.0.1", + "@radix-ui/react-use-size": "1.0.1", + "@radix-ui/rect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-portal": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.0.4.tgz", + "integrity": "sha512-Qki+C/EuGUVCQTOTD5vzJzJuMUlewbzuKyUy+/iHM2uwGiru9gZeBJtHAPKAEkB5KWGi9mP/CHKcY0wt1aW45Q==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-presence": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-presence/-/react-presence-1.0.1.tgz", + "integrity": "sha512-UXLW4UAbIY5ZjcvzjfRFo5gxva8QirC9hF7wRE4U5gz+TP0DbRk+//qyuAQ1McDxBt1xNMBTaciFGvEmJvAZCg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-primitive": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-1.0.3.tgz", + "integrity": "sha512-yi58uVyoAcK/Nq1inRY56ZSjKypBNKTa/1mcL8qdl6oJeEaDbOldlzrGn7P6Q3Id5d+SYNGc5AJgc4vGhjs5+g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-slot": "1.0.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-progress": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.0.3.tgz", + "integrity": "sha512-5G6Om/tYSxjSeEdrb1VfKkfZfn/1IlPWd731h2RfPuSbIfNUgfqAwbKfJCg/PP6nuUCTrYzalwHSpSinoWoCag==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-radio-group": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.1.3.tgz", + "integrity": "sha512-x+yELayyefNeKeTx4fjK6j99Fs6c4qKm3aY38G3swQVTN6xMpsrbigC0uHs2L//g8q4qR7qOcww8430jJmi2ag==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-roving-focus": "1.0.4", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-use-previous": "1.0.1", + "@radix-ui/react-use-size": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-roving-focus": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.0.4.tgz", + "integrity": "sha512-2mUg5Mgcu001VkGy+FfzZyzbmuUWzgWkj3rvv4yu+mLw03+mTzbxZHvfcGyFp2b8EkQeMkpRQ5FiA2Vr2O6TeQ==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-collection": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-controllable-state": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-scroll-area": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-scroll-area/-/react-scroll-area-1.0.5.tgz", + "integrity": "sha512-b6PAgH4GQf9QEn8zbT2XUHpW5z8BzqEc7Kl11TwDrvuTrxlkcjTD5qa/bxgKr+nmuXKu4L/W5UZ4mlP/VG/5Gw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/number": "1.0.1", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-select": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.0.0.tgz", + "integrity": "sha512-RH5b7af4oHtkcHS7pG6Sgv5rk5Wxa7XI8W5gvB1N/yiuDGZxko1ynvOiVhFM7Cis2A8zxF9bTOUVbRDzPepe6w==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/number": "1.0.1", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-collection": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-focus-guards": "1.0.1", + "@radix-ui/react-focus-scope": "1.0.4", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-popper": "1.1.3", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1", + "@radix-ui/react-use-previous": "1.0.1", + "@radix-ui/react-visually-hidden": "1.0.3", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.5.5" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-separator": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.0.3.tgz", + "integrity": "sha512-itYmTy/kokS21aiV5+Z56MZB54KrhPgn6eHDKkFeOLR34HMN2s8PaN47qZZAGnvupcjxHaFZnW4pQEh0BvvVuw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slider": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.1.2.tgz", + "integrity": "sha512-NKs15MJylfzVsCagVSWKhGGLNR1W9qWs+HtgbmjjVUB3B9+lb3PYoXxVju3kOrpf0VKyVCtZp+iTwVoqpa1Chw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/number": "1.0.1", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-collection": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1", + "@radix-ui/react-use-previous": "1.0.1", + "@radix-ui/react-use-size": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slot": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz", + "integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-compose-refs": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-switch": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.0.3.tgz", + "integrity": "sha512-mxm87F88HyHztsI7N+ZUmEoARGkC22YVW5CaC+Byc+HRpuvCrOBPTAnXgf+tZ/7i0Sg/eOePGdMhUKhPaQEqow==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-use-previous": "1.0.1", + "@radix-ui/react-use-size": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tabs": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.0.4.tgz", + "integrity": "sha512-egZfYY/+wRNCflXNHx+dePvnz9FbmssDTJBtgRfDY7e8SE5oIo3Py2eCB1ckAbh1Q7cQ/6yJZThJ++sgbxibog==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-direction": "1.0.1", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-roving-focus": "1.0.4", + "@radix-ui/react-use-controllable-state": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toast": { + "version": "1.1.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toast/-/react-toast-1.1.5.tgz", + "integrity": "sha512-fRLn227WHIBRSzuRzGJ8W+5YALxofH23y0MlPLddaIpLpCDqdE0NZlS2NRQDRiptfxDeeCjgFIpexB1/zkxDlw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-collection": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1", + "@radix-ui/react-visually-hidden": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-toggle": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle/-/react-toggle-1.0.3.tgz", + "integrity": "sha512-Pkqg3+Bc98ftZGsl60CLANXQBBQ4W3mTFS9EJvNxKMZ7magklKV69/id1mlAlOFDDfHvlCms0fx8fA4CMKDJHg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-controllable-state": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.0.7.tgz", + "integrity": "sha512-lPh5iKNFVQ/jav/j6ZrWq3blfDJ0OH9R6FlNUHPMqdLuQ9vwDgFsRxvl8b7Asuy5c8xmoojHUxKHQSOAvMHxyw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-popper": "1.1.3", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-visually-hidden": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-callback-ref": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz", + "integrity": "sha512-D94LjX4Sp0xJFVaoQOd3OO9k7tpBYNOXdVhkltUbGv2Qb9OXdrg/CpsjlZv7ia14Sylv398LswWBVVu5nqKzAQ==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-controllable-state": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.0.1.tgz", + "integrity": "sha512-Svl5GY5FQeN758fWKrjM6Qb7asvXeiZltlT4U2gVfl8Gx5UAv2sMR0LWo8yhsIZh2oQ0eFdZ59aoOOMV7b47VA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-callback-ref": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-escape-keydown": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.0.3.tgz", + "integrity": "sha512-vyL82j40hcFicA+M4Ex7hVkB9vHgSse1ZWomAqV2Je3RleKGO5iM8KMOEtfoSB0PnIelMd2lATjTGMYqN5ylTg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-callback-ref": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-layout-effect": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.0.1.tgz", + "integrity": "sha512-v/5RegiJWYdoCvMnITBkNNx6bCj20fiaJnWtRkU18yITptraXjffz5Qbn05uOiQnOvi+dbkznkoaMltz1GnszQ==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-previous": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-previous/-/react-use-previous-1.0.1.tgz", + "integrity": "sha512-cV5La9DPwiQ7S0gf/0qiD6YgNqM5Fk97Kdrlc5yBcrF3jyEZQwm7vYFqMo4IfeHgJXsRaMvLABFtd0OVEmZhDw==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-rect": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-rect/-/react-use-rect-1.0.1.tgz", + "integrity": "sha512-Cq5DLuSiuYVKNU8orzJMbl15TXilTnJKUCltMVQg53BQOF1/C5toAaGrowkgksdBQ9H+SRL23g0HDmg9tvmxXw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/rect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-size": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-size/-/react-use-size-1.0.1.tgz", + "integrity": "sha512-ibay+VqrgcaI6veAojjofPATwledXiSmX+C0KrBk/xgpX9rBzPV3OsfwlhQdUOFbh+LKQorLYT+xTXW9V8yd0g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-visually-hidden": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.0.3.tgz", + "integrity": "sha512-D4w41yN5YRKtu464TLnByKzMDG/JlMPHtfZgQAu9v6mNakUqGUI9vUrfQKz8NK41VMm/xbZbh76NUTVtIYqOMA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/rect": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/rect/-/rect-1.0.1.tgz", + "integrity": "sha512-fyrgCaedtvMg9NK3en0pnOYJdtfwxUcNolezkNPUsoX57X8oQk+NkqcvzHXD2uKNij6GXmWU9NDru2IWjrO4BQ==", + "dependencies": { + "@babel/runtime": "^7.13.10" + } + }, + "node_modules/@rollup/rollup-android-arm-eabi": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.5.0.tgz", + "integrity": "sha512-OINaBGY+Wc++U0rdr7BLuFClxcoWaVW3vQYqmQq6B3bqQ/2olkaoz+K8+af/Mmka/C2yN5j+L9scBkv4BtKsDA==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-android-arm64": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.5.0.tgz", + "integrity": "sha512-UdMf1pOQc4ZmUA/NTmKhgJTBimbSKnhPS2zJqucqFyBRFPnPDtwA8MzrGNTjDeQbIAWfpJVAlxejw+/lQyBK/w==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.5.0.tgz", + "integrity": "sha512-L0/CA5p/idVKI+c9PcAPGorH6CwXn6+J0Ys7Gg1axCbTPgI8MeMlhA6fLM9fK+ssFhqogMHFC8HDvZuetOii7w==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-darwin-x64": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.5.0.tgz", + "integrity": "sha512-QZCbVqU26mNlLn8zi/XDDquNmvcr4ON5FYAHQQsyhrHx8q+sQi/6xduoznYXwk/KmKIXG5dLfR0CvY+NAWpFYQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-linux-arm-gnueabihf": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.5.0.tgz", + "integrity": "sha512-VpSQ+xm93AeV33QbYslgf44wc5eJGYfYitlQzAi3OObu9iwrGXEnmu5S3ilkqE3Pr/FkgOiJKV/2p0ewf4Hrtg==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-gnu": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.5.0.tgz", + "integrity": "sha512-OrEyIfpxSsMal44JpEVx9AEcGpdBQG1ZuWISAanaQTSMeStBW+oHWwOkoqR54bw3x8heP8gBOyoJiGg+fLY8qQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-musl": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.5.0.tgz", + "integrity": "sha512-1H7wBbQuE6igQdxMSTjtFfD+DGAudcYWhp106z/9zBA8OQhsJRnemO4XGavdzHpGhRtRxbgmUGdO3YQgrWf2RA==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.5.0.tgz", + "integrity": "sha512-FVyFI13tXw5aE65sZdBpNjPVIi4Q5mARnL/39UIkxvSgRAIqCo5sCpCELk0JtXHGee2owZz5aNLbWNfBHzr71Q==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-musl": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.5.0.tgz", + "integrity": "sha512-eBPYl2sLpH/o8qbSz6vPwWlDyThnQjJfcDOGFbNjmjb44XKC1F5dQfakOsADRVrXCNzM6ZsSIPDG5dc6HHLNFg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-win32-arm64-msvc": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.5.0.tgz", + "integrity": "sha512-xaOHIfLOZypoQ5U2I6rEaugS4IYtTgP030xzvrBf5js7p9WI9wik07iHmsKaej8Z83ZDxN5GyypfoyKV5O5TJA==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-ia32-msvc": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.5.0.tgz", + "integrity": "sha512-Al6quztQUrHwcOoU2TuFblUQ5L+/AmPBXFR6dUvyo4nRj2yQRK0WIUaGMF/uwKulvRcXkpHe3k9A8Vf93VDktA==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-msvc": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.5.0.tgz", + "integrity": "sha512-8kdW+brNhI/NzJ4fxDufuJUjepzINqJKLGHuxyAtpPG9bMbn8P5mtaCcbOm0EzLJ+atg+kF9dwg8jpclkVqx5w==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@socket.io/component-emitter": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", + "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==" + }, + "node_modules/@swc/core": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core/-/core-1.3.96.tgz", + "integrity": "sha512-zwE3TLgoZwJfQygdv2SdCK9mRLYluwDOM53I+dT6Z5ZvrgVENmY3txvWDvduzkV+/8IuvrRbVezMpxcojadRdQ==", + "dev": true, + "hasInstallScript": true, + "dependencies": { + "@swc/counter": "^0.1.1", + "@swc/types": "^0.1.5" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/swc" + }, + "optionalDependencies": { + "@swc/core-darwin-arm64": "1.3.96", + "@swc/core-darwin-x64": "1.3.96", + "@swc/core-linux-arm-gnueabihf": "1.3.96", + "@swc/core-linux-arm64-gnu": "1.3.96", + "@swc/core-linux-arm64-musl": "1.3.96", + "@swc/core-linux-x64-gnu": "1.3.96", + "@swc/core-linux-x64-musl": "1.3.96", + "@swc/core-win32-arm64-msvc": "1.3.96", + "@swc/core-win32-ia32-msvc": "1.3.96", + "@swc/core-win32-x64-msvc": "1.3.96" + }, + "peerDependencies": { + "@swc/helpers": "^0.5.0" + }, + "peerDependenciesMeta": { + "@swc/helpers": { + "optional": true + } + } + }, + "node_modules/@swc/core-darwin-arm64": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-darwin-arm64/-/core-darwin-arm64-1.3.96.tgz", + "integrity": "sha512-8hzgXYVd85hfPh6mJ9yrG26rhgzCmcLO0h1TIl8U31hwmTbfZLzRitFQ/kqMJNbIBCwmNH1RU2QcJnL3d7f69A==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-darwin-x64": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-darwin-x64/-/core-darwin-x64-1.3.96.tgz", + "integrity": "sha512-mFp9GFfuPg+43vlAdQZl0WZpZSE8sEzqL7sr/7Reul5McUHP0BaLsEzwjvD035ESfkY8GBZdLpMinblIbFNljQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-arm-gnueabihf": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-arm-gnueabihf/-/core-linux-arm-gnueabihf-1.3.96.tgz", + "integrity": "sha512-8UEKkYJP4c8YzYIY/LlbSo8z5Obj4hqcv/fUTHiEePiGsOddgGf7AWjh56u7IoN/0uEmEro59nc1ChFXqXSGyg==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-arm64-gnu": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-arm64-gnu/-/core-linux-arm64-gnu-1.3.96.tgz", + "integrity": "sha512-c/IiJ0s1y3Ymm2BTpyC/xr6gOvoqAVETrivVXHq68xgNms95luSpbYQ28rqaZC8bQC8M5zdXpSc0T8DJu8RJGw==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-arm64-musl": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-arm64-musl/-/core-linux-arm64-musl-1.3.96.tgz", + "integrity": "sha512-i5/UTUwmJLri7zhtF6SAo/4QDQJDH2fhYJaBIUhrICmIkRO/ltURmpejqxsM/ye9Jqv5zG7VszMC0v/GYn/7BQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-x64-gnu": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-x64-gnu/-/core-linux-x64-gnu-1.3.96.tgz", + "integrity": "sha512-USdaZu8lTIkm4Yf9cogct/j5eqtdZqTgcTib4I+NloUW0E/hySou3eSyp3V2UAA1qyuC72ld1otXuyKBna0YKQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-x64-musl": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-x64-musl/-/core-linux-x64-musl-1.3.96.tgz", + "integrity": "sha512-QYErutd+G2SNaCinUVobfL7jWWjGTI0QEoQ6hqTp7PxCJS/dmKmj3C5ZkvxRYcq7XcZt7ovrYCTwPTHzt6lZBg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-win32-arm64-msvc": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-win32-arm64-msvc/-/core-win32-arm64-msvc-1.3.96.tgz", + "integrity": "sha512-hjGvvAduA3Un2cZ9iNP4xvTXOO4jL3G9iakhFsgVhpkU73SGmK7+LN8ZVBEu4oq2SUcHO6caWvnZ881cxGuSpg==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-win32-ia32-msvc": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-win32-ia32-msvc/-/core-win32-ia32-msvc-1.3.96.tgz", + "integrity": "sha512-Far2hVFiwr+7VPCM2GxSmbh3ikTpM3pDombE+d69hkedvYHYZxtTF+2LTKl/sXtpbUnsoq7yV/32c9R/xaaWfw==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-win32-x64-msvc": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-win32-x64-msvc/-/core-win32-x64-msvc-1.3.96.tgz", + "integrity": "sha512-4VbSAniIu0ikLf5mBX81FsljnfqjoVGleEkCQv4+zRlyZtO3FHoDPkeLVoy6WRlj7tyrRcfUJ4mDdPkbfTO14g==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/counter": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/@swc/counter/-/counter-0.1.2.tgz", + "integrity": "sha512-9F4ys4C74eSTEUNndnER3VJ15oru2NumfQxS8geE+f3eB5xvfxpWyqE5XlVnxb/R14uoXi6SLbBwwiDSkv+XEw==", + "dev": true + }, + "node_modules/@swc/helpers": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.2.tgz", + "integrity": "sha512-E4KcWTpoLHqwPHLxidpOqQbcrZVgi0rsmmZXUle1jXmJfuIf/UWpczUJ7MZZ5tlxytgJXyp0w4PGkkeLiuIdZw==", + "peer": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, + "node_modules/@swc/types": { + "version": "0.1.5", + "resolved": "https://registry.npmjs.org/@swc/types/-/types-0.1.5.tgz", + "integrity": "sha512-myfUej5naTBWnqOCc/MdVOLVjXUXtIA+NpDrDBKJtLLg2shUjBu3cZmB/85RyitKc55+lUUyl7oRfLOvkr2hsw==", + "dev": true + }, + "node_modules/@tanstack/eslint-plugin-query": { + "version": "5.8.4", + "resolved": "https://registry.npmjs.org/@tanstack/eslint-plugin-query/-/eslint-plugin-query-5.8.4.tgz", + "integrity": "sha512-KVgcMc+Bn1qbwkxYVWQoiVSNEIN4IAiLj3cUH/SAHT8m8E59Y97o8ON1syp0Rcw094ItG8pEVZFyQuOaH6PDgQ==", + "dev": true, + "dependencies": { + "@typescript-eslint/utils": "^5.54.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "eslint": "^8.0.0" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/scope-manager": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-5.62.0.tgz", + "integrity": "sha512-VXuvVvZeQCQb5Zgf4HAxc04q5j+WrNAtNh9OwCsCgpKqESMTu3tF/jhZ3xG6T4NZwWl65Bg8KuS2uEvhSfLl0w==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/visitor-keys": "5.62.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/types": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-5.62.0.tgz", + "integrity": "sha512-87NVngcbVXUahrRTqIK27gD2t5Cu1yuCXxbLcFtCzZGlfyVWWh8mLHkoxzjsB6DDNnvdL+fW8MiwPEJyGJQDgQ==", + "dev": true, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/typescript-estree": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-5.62.0.tgz", + "integrity": "sha512-CmcQ6uY7b9y694lKdRB8FEel7JbU/40iSAPomu++SjLMntB+2Leay2LO6i8VnJk58MtE9/nQSFIH6jpyRWyYzA==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/visitor-keys": "5.62.0", + "debug": "^4.3.4", + "globby": "^11.1.0", + "is-glob": "^4.0.3", + "semver": "^7.3.7", + "tsutils": "^3.21.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/utils": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-5.62.0.tgz", + "integrity": "sha512-n8oxjeb5aIbPFEtmQxQYOLI0i9n5ySBEY/ZEHHZqKQSFnxio1rv6dthascc9dLuwrL0RC5mPCxB7vnAVGAYWAQ==", + "dev": true, + "dependencies": { + "@eslint-community/eslint-utils": "^4.2.0", + "@types/json-schema": "^7.0.9", + "@types/semver": "^7.3.12", + "@typescript-eslint/scope-manager": "5.62.0", + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/typescript-estree": "5.62.0", + "eslint-scope": "^5.1.1", + "semver": "^7.3.7" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/visitor-keys": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-5.62.0.tgz", + "integrity": "sha512-07ny+LHRzQXepkGg6w0mFY41fVUNBrL2Roj/++7V1txKugfjm/Ci/qSND03r2RhlJhJYMcTn9AhhSSqQp0Ysyw==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "5.62.0", + "eslint-visitor-keys": "^3.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "dev": true, + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/@tanstack/query-core": { + "version": "5.8.7", + "resolved": "https://registry.npmjs.org/@tanstack/query-core/-/query-core-5.8.7.tgz", + "integrity": "sha512-58xOSkxxZK4SGQ/uzX8MDZHLGZCkxlgkPxnfhxUOL2uchnNHyay2UVcR3mQNMgaMwH1e2l+0n+zfS7+UJ/MAJw==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + } + }, + "node_modules/@tanstack/react-query": { + "version": "5.8.7", + "resolved": "https://registry.npmjs.org/@tanstack/react-query/-/react-query-5.8.7.tgz", + "integrity": "sha512-RYSSMmkhbJ7tPkf8w+MSRIXQLoUCm7DRnTLDcdf+uampupnriEsob3fVWTt9oaEj+AJWEKeCErDBdZeNcAzURQ==", + "dependencies": { + "@tanstack/query-core": "5.8.7" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "react": "^18.0.0", + "react-dom": "^18.0.0", + "react-native": "*" + }, + "peerDependenciesMeta": { + "react-dom": { + "optional": true + }, + "react-native": { + "optional": true + } + } + }, + "node_modules/@types/axios": { + "version": "0.14.4", + "resolved": "https://registry.npmjs.org/@types/axios/-/axios-0.14.4.tgz", + "integrity": "sha512-9JgOaunvQdsQ/qW2OPmE5+hCeUB52lQSolecrFrthct55QekhmXEwT203s20RL+UHtCQc15y3VXpby9E7Kkh/g==", + "deprecated": "This is a stub types definition. axios provides its own type definitions, so you do not need this installed.", + "dev": true, + "license": "MIT", + "dependencies": { + "axios": "*" + } + }, + "node_modules/@types/babel__core": { + "version": "7.20.4", + "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.4.tgz", + "integrity": "sha512-mLnSC22IC4vcWiuObSRjrLd9XcBTGf59vUSoq2jkQDJ/QQ8PMI9rSuzE+aEV8karUMbskw07bKYoUJCKTUaygg==", + "dev": true, + "dependencies": { + "@babel/parser": "^7.20.7", + "@babel/types": "^7.20.7", + "@types/babel__generator": "*", + "@types/babel__template": "*", + "@types/babel__traverse": "*" + } + }, + "node_modules/@types/babel__generator": { + "version": "7.6.7", + "resolved": "https://registry.npmjs.org/@types/babel__generator/-/babel__generator-7.6.7.tgz", + "integrity": "sha512-6Sfsq+EaaLrw4RmdFWE9Onp63TOUue71AWb4Gpa6JxzgTYtimbM086WnYTy2U67AofR++QKCo08ZP6pwx8YFHQ==", + "dev": true, + "dependencies": { + "@babel/types": "^7.0.0" + } + }, + "node_modules/@types/babel__template": { + "version": "7.4.4", + "resolved": "https://registry.npmjs.org/@types/babel__template/-/babel__template-7.4.4.tgz", + "integrity": "sha512-h/NUaSyG5EyxBIp8YRxo4RMe2/qQgvyowRwVMzhYhBCONbW8PUsg4lkFMrhgZhUe5z3L3MiLDuvyJ/CaPa2A8A==", + "dev": true, + "dependencies": { + "@babel/parser": "^7.1.0", + "@babel/types": "^7.0.0" + } + }, + "node_modules/@types/babel__traverse": { + "version": "7.20.4", + "resolved": "https://registry.npmjs.org/@types/babel__traverse/-/babel__traverse-7.20.4.tgz", + "integrity": "sha512-mSM/iKUk5fDDrEV/e83qY+Cr3I1+Q3qqTuEn++HAWYjEa1+NxZr6CNrcJGf2ZTnq4HoFGC3zaTPZTobCzCFukA==", + "dev": true, + "dependencies": { + "@babel/types": "^7.20.7" + } + }, + "node_modules/@types/flexsearch": { + "version": "0.7.6", + "resolved": "https://registry.npmjs.org/@types/flexsearch/-/flexsearch-0.7.6.tgz", + "integrity": "sha512-H5IXcRn96/gaDmo+rDl2aJuIJsob8dgOXDqf8K0t8rWZd1AFNaaspmRsElESiU+EWE33qfbFPgI0OC/B1g9FCA==", + "dev": true + }, + "node_modules/@types/js-cookie": { + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@types/js-cookie/-/js-cookie-2.2.7.tgz", + "integrity": "sha512-aLkWa0C0vO5b4Sr798E26QgOkss68Un0bLjs7u9qxzPT5CG+8DuNTffWES58YzJs3hrVAOs1wonycqEBqNJubA==" + }, + "node_modules/@types/json-schema": { + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", + "dev": true + }, + "node_modules/@types/lodash": { + "version": "4.14.201", + "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.14.201.tgz", + "integrity": "sha512-y9euML0cim1JrykNxADLfaG0FgD1g/yTHwUs/Jg9ZIU7WKj2/4IW9Lbb1WZbvck78W/lfGXFfe+u2EGfIJXdLQ==", + "dev": true + }, + "node_modules/@types/node": { + "version": "20.9.2", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.9.2.tgz", + "integrity": "sha512-WHZXKFCEyIUJzAwh3NyyTHYSR35SevJ6mZ1nWwJafKtiQbqRTIKSRcw3Ma3acqgsent3RRDqeVwpHntMk+9irg==", + "dev": true, + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/@types/prop-types": { + "version": "15.7.10", + "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.10.tgz", + "integrity": "sha512-mxSnDQxPqsZxmeShFH+uwQ4kO4gcJcGahjjMFeLbKE95IAZiiZyiEepGZjtXJ7hN/yfu0bu9xN2ajcU0JcxX6A==", + "devOptional": true + }, + "node_modules/@types/react": { + "version": "18.2.37", + "resolved": "https://registry.npmjs.org/@types/react/-/react-18.2.37.tgz", + "integrity": "sha512-RGAYMi2bhRgEXT3f4B92WTohopH6bIXw05FuGlmJEnv/omEn190+QYEIYxIAuIBdKgboYYdVved2p1AxZVQnaw==", + "devOptional": true, + "dependencies": { + "@types/prop-types": "*", + "@types/scheduler": "*", + "csstype": "^3.0.2" + } + }, + "node_modules/@types/react-dom": { + "version": "18.2.15", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.2.15.tgz", + "integrity": "sha512-HWMdW+7r7MR5+PZqJF6YFNSCtjz1T0dsvo/f1BV6HkV+6erD/nA7wd9NM00KVG83zf2nJ7uATPO9ttdIPvi3gg==", + "devOptional": true, + "dependencies": { + "@types/react": "*" + } + }, + "node_modules/@types/scheduler": { + "version": "0.16.6", + "resolved": "https://registry.npmjs.org/@types/scheduler/-/scheduler-0.16.6.tgz", + "integrity": "sha512-Vlktnchmkylvc9SnwwwozTv04L/e1NykF5vgoQ0XTmI8DD+wxfjQuHuvHS3p0r2jz2x2ghPs2h1FVeDirIteWA==", + "devOptional": true + }, + "node_modules/@types/semver": { + "version": "7.5.5", + "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.5.5.tgz", + "integrity": "sha512-+d+WYC1BxJ6yVOgUgzK8gWvp5qF8ssV5r4nsDcZWKRWcDQLQ619tvWAxJQYGgBrO1MnLJC7a5GtiYsAoQ47dJg==", + "dev": true + }, + "node_modules/@typescript-eslint/eslint-plugin": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-6.11.0.tgz", + "integrity": "sha512-uXnpZDc4VRjY4iuypDBKzW1rz9T5YBBK0snMn8MaTSNd2kMlj50LnLBABELjJiOL5YHk7ZD8hbSpI9ubzqYI0w==", + "dev": true, + "dependencies": { + "@eslint-community/regexpp": "^4.5.1", + "@typescript-eslint/scope-manager": "6.11.0", + "@typescript-eslint/type-utils": "6.11.0", + "@typescript-eslint/utils": "6.11.0", + "@typescript-eslint/visitor-keys": "6.11.0", + "debug": "^4.3.4", + "graphemer": "^1.4.0", + "ignore": "^5.2.4", + "natural-compare": "^1.4.0", + "semver": "^7.5.4", + "ts-api-utils": "^1.0.1" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "@typescript-eslint/parser": "^6.0.0 || ^6.0.0-alpha", + "eslint": "^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/@typescript-eslint/parser": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-6.11.0.tgz", + "integrity": "sha512-+whEdjk+d5do5nxfxx73oanLL9ghKO3EwM9kBCkUtWMRwWuPaFv9ScuqlYfQ6pAD6ZiJhky7TZ2ZYhrMsfMxVQ==", + "dev": true, + "dependencies": { + "@typescript-eslint/scope-manager": "6.11.0", + "@typescript-eslint/types": "6.11.0", + "@typescript-eslint/typescript-estree": "6.11.0", + "@typescript-eslint/visitor-keys": "6.11.0", + "debug": "^4.3.4" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/@typescript-eslint/scope-manager": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-6.11.0.tgz", + "integrity": "sha512-0A8KoVvIURG4uhxAdjSaxy8RdRE//HztaZdG8KiHLP8WOXSk0vlF7Pvogv+vlJA5Rnjj/wDcFENvDaHb+gKd1A==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "6.11.0", + "@typescript-eslint/visitor-keys": "6.11.0" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/type-utils": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-6.11.0.tgz", + "integrity": "sha512-nA4IOXwZtqBjIoYrJcYxLRO+F9ri+leVGoJcMW1uqr4r1Hq7vW5cyWrA43lFbpRvQ9XgNrnfLpIkO3i1emDBIA==", + "dev": true, + "dependencies": { + "@typescript-eslint/typescript-estree": "6.11.0", + "@typescript-eslint/utils": "6.11.0", + "debug": "^4.3.4", + "ts-api-utils": "^1.0.1" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/@typescript-eslint/types": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-6.11.0.tgz", + "integrity": "sha512-ZbEzuD4DwEJxwPqhv3QULlRj8KYTAnNsXxmfuUXFCxZmO6CF2gM/y+ugBSAQhrqaJL3M+oe4owdWunaHM6beqA==", + "dev": true, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/typescript-estree": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-6.11.0.tgz", + "integrity": "sha512-Aezzv1o2tWJwvZhedzvD5Yv7+Lpu1by/U1LZ5gLc4tCx8jUmuSCMioPFRjliN/6SJIvY6HpTtJIWubKuYYYesQ==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "6.11.0", + "@typescript-eslint/visitor-keys": "6.11.0", + "debug": "^4.3.4", + "globby": "^11.1.0", + "is-glob": "^4.0.3", + "semver": "^7.5.4", + "ts-api-utils": "^1.0.1" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/@typescript-eslint/utils": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-6.11.0.tgz", + "integrity": "sha512-p23ibf68fxoZy605dc0dQAEoUsoiNoP3MD9WQGiHLDuTSOuqoTsa4oAy+h3KDkTcxbbfOtUjb9h3Ta0gT4ug2g==", + "dev": true, + "dependencies": { + "@eslint-community/eslint-utils": "^4.4.0", + "@types/json-schema": "^7.0.12", + "@types/semver": "^7.5.0", + "@typescript-eslint/scope-manager": "6.11.0", + "@typescript-eslint/types": "6.11.0", + "@typescript-eslint/typescript-estree": "6.11.0", + "semver": "^7.5.4" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^7.0.0 || ^8.0.0" + } + }, + "node_modules/@typescript-eslint/visitor-keys": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-6.11.0.tgz", + "integrity": "sha512-+SUN/W7WjBr05uRxPggJPSzyB8zUpaYo2hByKasWbqr3PM8AXfZt8UHdNpBS1v9SA62qnSSMF3380SwDqqprgQ==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "6.11.0", + "eslint-visitor-keys": "^3.4.1" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@uidotdev/usehooks": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/@uidotdev/usehooks/-/usehooks-2.4.1.tgz", + "integrity": "sha512-1I+RwWyS+kdv3Mv0Vmc+p0dPYH0DTRAo04HLyXReYBL9AeseDWUJyi4THuksBJcu9F0Pih69Ak150VDnqbVnXg==", + "engines": { + "node": ">=16" + }, + "peerDependencies": { + "react": ">=18.0.0", + "react-dom": ">=18.0.0" + } + }, + "node_modules/@ungap/structured-clone": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.2.0.tgz", + "integrity": "sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==", + "dev": true + }, + "node_modules/@vitejs/plugin-react": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.2.0.tgz", + "integrity": "sha512-+MHTH/e6H12kRp5HUkzOGqPMksezRMmW+TNzlh/QXfI8rRf6l2Z2yH/v12no1UvTwhZgEDMuQ7g7rrfMseU6FQ==", + "dev": true, + "dependencies": { + "@babel/core": "^7.23.3", + "@babel/plugin-transform-react-jsx-self": "^7.23.3", + "@babel/plugin-transform-react-jsx-source": "^7.23.3", + "@types/babel__core": "^7.20.4", + "react-refresh": "^0.14.0" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "peerDependencies": { + "vite": "^4.2.0 || ^5.0.0" + } + }, + "node_modules/@vitejs/plugin-react-swc": { + "version": "3.5.0", + "resolved": "https://registry.npmjs.org/@vitejs/plugin-react-swc/-/plugin-react-swc-3.5.0.tgz", + "integrity": "sha512-1PrOvAaDpqlCV+Up8RkAh9qaiUjoDUcjtttyhXDKw53XA6Ve16SOp6cCOpRs8Dj8DqUQs6eTW5YkLcLJjrXAig==", + "dev": true, + "dependencies": { + "@swc/core": "^1.3.96" + }, + "peerDependencies": { + "vite": "^4 || ^5" + } + }, + "node_modules/@xobotyi/scrollbar-width": { + "version": "1.9.5", + "resolved": "https://registry.npmjs.org/@xobotyi/scrollbar-width/-/scrollbar-width-1.9.5.tgz", + "integrity": "sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==" + }, + "node_modules/acorn": { + "version": "8.11.2", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.11.2.tgz", + "integrity": "sha512-nc0Axzp/0FILLEVsm4fNwLCwMttvhEI263QtVPQcbpfZZ3ts0hLsZGOpE6czNlid7CJ9MlyH8reXkpsf3YUY4w==", + "dev": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "dev": true, + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/any-promise": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", + "integrity": "sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==" + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/arg": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/arg/-/arg-5.0.2.tgz", + "integrity": "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==" + }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "dev": true + }, + "node_modules/aria-hidden": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/aria-hidden/-/aria-hidden-1.2.3.tgz", + "integrity": "sha512-xcLxITLe2HYa1cnYnwCjkOO1PqUHQpozB8x9AR0OgWN2woOBi5kSDVxKfd0b7sb1hw5qFeJhXm9H1nu3xSfLeQ==", + "dependencies": { + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/array-union": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", + "integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "license": "MIT" + }, + "node_modules/autoprefixer": { + "version": "10.4.16", + "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz", + "integrity": "sha512-7vd3UC6xKp0HLfua5IjZlcXvGAGy7cBAXTg2lyQ/8WpNhd6SiZ8Be+xm3FyBSYJx5GKcpRCzBh7RH4/0dnY+uQ==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/autoprefixer" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "browserslist": "^4.21.10", + "caniuse-lite": "^1.0.30001538", + "fraction.js": "^4.3.6", + "normalize-range": "^0.1.2", + "picocolors": "^1.0.0", + "postcss-value-parser": "^4.2.0" + }, + "bin": { + "autoprefixer": "bin/autoprefixer" + }, + "engines": { + "node": "^10 || ^12 || >=14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/axios": { + "version": "1.11.0", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.11.0.tgz", + "integrity": "sha512-1Lx3WLFQWm3ooKDYZD1eXmoGO9fxYQjrycfHFC8P0sCfQVXyROp0p9PFWBehewBOdCwHc+f/b8I0fMto5eSfwA==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.15.6", + "form-data": "^4.0.4", + "proxy-from-env": "^1.1.0" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==" + }, + "node_modules/binary-extensions": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", + "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "engines": { + "node": ">=8" + } + }, + "node_modules/brace-expansion": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/braces": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", + "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "dependencies": { + "fill-range": "^7.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/browserslist": { + "version": "4.22.1", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.22.1.tgz", + "integrity": "sha512-FEVc202+2iuClEhZhrWy6ZiAcRLvNMyYcxZ8raemul1DYVOVdFsbqckWLdsixQZCpJlwe77Z3UTalE7jsjnKfQ==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "caniuse-lite": "^1.0.30001541", + "electron-to-chromium": "^1.4.535", + "node-releases": "^2.0.13", + "update-browserslist-db": "^1.0.13" + }, + "bin": { + "browserslist": "cli.js" + }, + "engines": { + "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" + } + }, + "node_modules/busboy": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/busboy/-/busboy-1.6.0.tgz", + "integrity": "sha512-8SFQbg/0hQ9xy3UNTB0YEnsNBbWfhf7RtnzpL7TkBiTBRfrQ9Fxcnz7VJsleJpyp6rVLvXiuORqjlHi5q+PYuA==", + "peer": true, + "dependencies": { + "streamsearch": "^1.1.0" + }, + "engines": { + "node": ">=10.16.0" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/camelcase-css": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz", + "integrity": "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==", + "engines": { + "node": ">= 6" + } + }, + "node_modules/caniuse-lite": { + "version": "1.0.30001563", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001563.tgz", + "integrity": "sha512-na2WUmOxnwIZtwnFI2CZ/3er0wdNzU7hN+cPYz/z2ajHThnkWjNBOpEPP4n+4r2WPM847JaMotaJE3bnfzjyKw==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ] + }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/chokidar": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", + "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "funding": [ + { + "type": "individual", + "url": "https://paulmillr.com/funding/" + } + ], + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/chokidar/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/class-variance-authority": { + "version": "0.7.0", + "resolved": "https://registry.npmjs.org/class-variance-authority/-/class-variance-authority-0.7.0.tgz", + "integrity": "sha512-jFI8IQw4hczaL4ALINxqLEXQbWcNjoSkloa4IaufXCJr6QawJyw7tuRysRsrE8w2p/4gGaxKIt/hX3qz/IbD1A==", + "dependencies": { + "clsx": "2.0.0" + }, + "funding": { + "url": "https://joebell.co.uk" + } + }, + "node_modules/client-only": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz", + "integrity": "sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==", + "peer": true + }, + "node_modules/clsx": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.0.0.tgz", + "integrity": "sha512-rQ1+kcj+ttHG0MKVGBUXwayCCF1oh39BF5COIpRzuCEv8Mwjv0XucrI2ExNTOn9IlLifGClWQcU9BrZORvtw6Q==", + "engines": { + "node": ">=6" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "license": "MIT", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/commander": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz", + "integrity": "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==", + "engines": { + "node": ">= 6" + } + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==" + }, + "node_modules/convert-source-map": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz", + "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==", + "dev": true + }, + "node_modules/copy-to-clipboard": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz", + "integrity": "sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==", + "dependencies": { + "toggle-selection": "^1.0.6" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", + "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "dev": true, + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/css-in-js-utils": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/css-in-js-utils/-/css-in-js-utils-3.1.0.tgz", + "integrity": "sha512-fJAcud6B3rRu+KHYk+Bwf+WFL2MDCJJ1XG9x137tJQ0xYxor7XziQtuGFbWNdqrvF4Tk26O3H73nfVqXt/fW1A==", + "dependencies": { + "hyphenate-style-name": "^1.0.3" + } + }, + "node_modules/css-tree": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-1.1.3.tgz", + "integrity": "sha512-tRpdppF7TRazZrjJ6v3stzv93qxRcSsFmW6cX0Zm2NVKpxE1WV1HblnghVv9TreireHkqI/VDEsfolRF1p6y7Q==", + "dependencies": { + "mdn-data": "2.0.14", + "source-map": "^0.6.1" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/cssesc": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", + "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", + "bin": { + "cssesc": "bin/cssesc" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/csstype": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.2.tgz", + "integrity": "sha512-I7K1Uu0MBPzaFKg4nI5Q7Vs2t+3gWWW648spaF+Rg7pI9ds18Ugn+lvg4SHczUdKlHI5LWBXyqfS8+DufyBsgQ==" + }, + "node_modules/debug": { + "version": "4.3.4", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", + "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "dependencies": { + "ms": "2.1.2" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/deep-is": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", + "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", + "dev": true + }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/detect-node-es": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz", + "integrity": "sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==" + }, + "node_modules/didyoumean": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", + "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==" + }, + "node_modules/dir-glob": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", + "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==", + "dev": true, + "dependencies": { + "path-type": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/dlv": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz", + "integrity": "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==" + }, + "node_modules/doctrine": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", + "integrity": "sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w==", + "dev": true, + "dependencies": { + "esutils": "^2.0.2" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/electron-to-chromium": { + "version": "1.4.588", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.588.tgz", + "integrity": "sha512-soytjxwbgcCu7nh5Pf4S2/4wa6UIu+A3p03U2yVr53qGxi1/VTR3ENI+p50v+UxqqZAfl48j3z55ud7VHIOr9w==", + "dev": true + }, + "node_modules/engine.io-client": { + "version": "6.5.3", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz", + "integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1", + "engine.io-parser": "~5.2.1", + "ws": "~8.11.0", + "xmlhttprequest-ssl": "~2.0.0" + } + }, + "node_modules/engine.io-parser": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.1.tgz", + "integrity": "sha512-9JktcM3u18nU9N2Lz3bWeBgxVgOKpw7yhRaoxQA3FUDZzzw+9WlA6p4G4u0RixNkg14fH7EfEc/RhpurtiROTQ==", + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/error-stack-parser": { + "version": "2.1.4", + "resolved": "https://registry.npmjs.org/error-stack-parser/-/error-stack-parser-2.1.4.tgz", + "integrity": "sha512-Sk5V6wVazPhq5MhpO+AUxJn5x7XSXGl1R93Vn7i+zS15KDVxQijejNCrz8340/2bgLBjR9GtEG8ZVKONDjcqGQ==", + "dependencies": { + "stackframe": "^1.3.4" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/esbuild": { + "version": "0.19.6", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.6.tgz", + "integrity": "sha512-Xl7dntjA2OEIvpr9j0DVxxnog2fyTGnyVoQXAMQI6eR3mf9zCQds7VIKUDCotDgE/p4ncTgeRqgX8t5d6oP4Gw==", + "dev": true, + "hasInstallScript": true, + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=12" + }, + "optionalDependencies": { + "@esbuild/android-arm": "0.19.6", + "@esbuild/android-arm64": "0.19.6", + "@esbuild/android-x64": "0.19.6", + "@esbuild/darwin-arm64": "0.19.6", + "@esbuild/darwin-x64": "0.19.6", + "@esbuild/freebsd-arm64": "0.19.6", + "@esbuild/freebsd-x64": "0.19.6", + "@esbuild/linux-arm": "0.19.6", + "@esbuild/linux-arm64": "0.19.6", + "@esbuild/linux-ia32": "0.19.6", + "@esbuild/linux-loong64": "0.19.6", + "@esbuild/linux-mips64el": "0.19.6", + "@esbuild/linux-ppc64": "0.19.6", + "@esbuild/linux-riscv64": "0.19.6", + "@esbuild/linux-s390x": "0.19.6", + "@esbuild/linux-x64": "0.19.6", + "@esbuild/netbsd-x64": "0.19.6", + "@esbuild/openbsd-x64": "0.19.6", + "@esbuild/sunos-x64": "0.19.6", + "@esbuild/win32-arm64": "0.19.6", + "@esbuild/win32-ia32": "0.19.6", + "@esbuild/win32-x64": "0.19.6" + } + }, + "node_modules/escalade": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", + "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/escape-string-regexp": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/eslint": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-8.54.0.tgz", + "integrity": "sha512-NY0DfAkM8BIZDVl6PgSa1ttZbx3xHgJzSNJKYcQglem6CppHyMhRIQkBVSSMaSRnLhig3jsDbEzOjwCVt4AmmA==", + "dev": true, + "dependencies": { + "@eslint-community/eslint-utils": "^4.2.0", + "@eslint-community/regexpp": "^4.6.1", + "@eslint/eslintrc": "^2.1.3", + "@eslint/js": "8.54.0", + "@humanwhocodes/config-array": "^0.11.13", + "@humanwhocodes/module-importer": "^1.0.1", + "@nodelib/fs.walk": "^1.2.8", + "@ungap/structured-clone": "^1.2.0", + "ajv": "^6.12.4", + "chalk": "^4.0.0", + "cross-spawn": "^7.0.2", + "debug": "^4.3.2", + "doctrine": "^3.0.0", + "escape-string-regexp": "^4.0.0", + "eslint-scope": "^7.2.2", + "eslint-visitor-keys": "^3.4.3", + "espree": "^9.6.1", + "esquery": "^1.4.2", + "esutils": "^2.0.2", + "fast-deep-equal": "^3.1.3", + "file-entry-cache": "^6.0.1", + "find-up": "^5.0.0", + "glob-parent": "^6.0.2", + "globals": "^13.19.0", + "graphemer": "^1.4.0", + "ignore": "^5.2.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "is-path-inside": "^3.0.3", + "js-yaml": "^4.1.0", + "json-stable-stringify-without-jsonify": "^1.0.1", + "levn": "^0.4.1", + "lodash.merge": "^4.6.2", + "minimatch": "^3.1.2", + "natural-compare": "^1.4.0", + "optionator": "^0.9.3", + "strip-ansi": "^6.0.1", + "text-table": "^0.2.0" + }, + "bin": { + "eslint": "bin/eslint.js" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-plugin-react-hooks": { + "version": "4.6.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-4.6.0.tgz", + "integrity": "sha512-oFc7Itz9Qxh2x4gNHStv3BqJq54ExXmfC+a1NjAta66IAN87Wu0R/QArgIS9qKzX3dXKPI9H5crl9QchNMY9+g==", + "dev": true, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "eslint": "^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0" + } + }, + "node_modules/eslint-plugin-react-refresh": { + "version": "0.4.4", + "resolved": "https://registry.npmjs.org/eslint-plugin-react-refresh/-/eslint-plugin-react-refresh-0.4.4.tgz", + "integrity": "sha512-eD83+65e8YPVg6603Om2iCIwcQJf/y7++MWm4tACtEswFLYMwxwVWAfwN+e19f5Ad/FOyyNg9Dfi5lXhH3Y3rA==", + "dev": true, + "peerDependencies": { + "eslint": ">=7" + } + }, + "node_modules/eslint-scope": { + "version": "7.2.2", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", + "integrity": "sha512-dOt21O7lTMhDM+X9mB4GX+DZrZtCUJPL/wlcTqxyrx5IvO0IYtILdtrQGQp+8n5S0gwSVmOf9NQrjMOgfQZlIg==", + "dev": true, + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^5.2.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/espree": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/espree/-/espree-9.6.1.tgz", + "integrity": "sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ==", + "dev": true, + "dependencies": { + "acorn": "^8.9.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^3.4.1" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/esquery": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.5.0.tgz", + "integrity": "sha512-YQLXUplAwJgCydQ78IMJywZCceoqk1oH01OERdSAJc/7U2AylwjhSCLDEtqwg811idIS/9fIU5GjG73IgjKMVg==", + "dev": true, + "dependencies": { + "estraverse": "^5.1.0" + }, + "engines": { + "node": ">=0.10" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==" + }, + "node_modules/fast-glob": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.2.tgz", + "integrity": "sha512-oX2ruAFQwf/Orj8m737Y5adxDQO0LAB7/S5MnxCdTNDd4p6BsyIVsv9JQsATbTSq8KHRpLwIHbVlUNatxd+1Ow==", + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.4" + }, + "engines": { + "node": ">=8.6.0" + } + }, + "node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "dev": true + }, + "node_modules/fast-loops": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/fast-loops/-/fast-loops-1.1.3.tgz", + "integrity": "sha512-8EZzEP0eKkEEVX+drtd9mtuQ+/QrlfW/5MlwcwK5Nds6EkZ/tRzEexkzUY2mIssnAyVLT+TKHuRXmFNNXYUd6g==" + }, + "node_modules/fast-shallow-equal": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fast-shallow-equal/-/fast-shallow-equal-1.0.0.tgz", + "integrity": "sha512-HPtaa38cPgWvaCFmRNhlc6NG7pv6NUHqjPgVAkWGoB9mQMwYB27/K0CvOM5Czy+qpT3e8XJ6Q4aPAnzpNpzNaw==" + }, + "node_modules/fastest-stable-stringify": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/fastest-stable-stringify/-/fastest-stable-stringify-2.0.2.tgz", + "integrity": "sha512-bijHueCGd0LqqNK9b5oCMHc0MluJAx0cwqASgbWMvkO01lCYgIhacVRLcaDz3QnyYIRNJRDwMb41VuT6pHJ91Q==" + }, + "node_modules/fastq": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/file-entry-cache": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz", + "integrity": "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==", + "dev": true, + "dependencies": { + "flat-cache": "^3.0.4" + }, + "engines": { + "node": "^10.12.0 || >=12.0.0" + } + }, + "node_modules/fill-range": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", + "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/find-up": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", + "integrity": "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", + "dev": true, + "dependencies": { + "locate-path": "^6.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/flat-cache": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.2.0.tgz", + "integrity": "sha512-CYcENa+FtcUKLmhhqyctpclsq7QF38pKjZHsGNiSQF5r4FtoKDWabFDl3hzaEQMvT1LHEysw5twgLvpYYb4vbw==", + "dev": true, + "dependencies": { + "flatted": "^3.2.9", + "keyv": "^4.5.3", + "rimraf": "^3.0.2" + }, + "engines": { + "node": "^10.12.0 || >=12.0.0" + } + }, + "node_modules/flatted": { + "version": "3.2.9", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.2.9.tgz", + "integrity": "sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==", + "dev": true + }, + "node_modules/follow-redirects": { + "version": "1.15.9", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.9.tgz", + "integrity": "sha512-gew4GsXizNgdoRyqmyfMHyAmXsZDk6mHkSxZFCzW9gwlbtOW44CDtYavM+y+72qD/Vq2l550kMF52DT8fOLJqQ==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/form-data": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", + "integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", + "license": "MIT", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fraction.js": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz", + "integrity": "sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew==", + "dev": true, + "engines": { + "node": "*" + }, + "funding": { + "type": "patreon", + "url": "https://github.com/sponsors/rawify" + } + }, + "node_modules/fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==" + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/fuse.js": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/fuse.js/-/fuse.js-7.0.0.tgz", + "integrity": "sha512-14F4hBIxqKvD4Zz/XjDc3y94mNZN6pRv3U13Udo0lNLCWRBUsrMv2xwcF/y/Z5sV6+FQW+/ow68cHpm4sunt8Q==", + "engines": { + "node": ">=10" + } + }, + "node_modules/gensync": { + "version": "1.0.0-beta.2", + "resolved": "https://registry.npmjs.org/gensync/-/gensync-1.0.0-beta.2.tgz", + "integrity": "sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==", + "dev": true, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-nonce": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-nonce/-/get-nonce-1.0.1.tgz", + "integrity": "sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==", + "engines": { + "node": ">=6" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "dev": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/glob-to-regexp": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz", + "integrity": "sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==", + "peer": true + }, + "node_modules/globals": { + "version": "13.23.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-13.23.0.tgz", + "integrity": "sha512-XAmF0RjlrjY23MA51q3HltdlGxUpXPvg0GioKiD9X6HD28iMjo2dKC8Vqwm7lne4GNr78+RHTfliktR6ZH09wA==", + "dev": true, + "dependencies": { + "type-fest": "^0.20.2" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/globby": { + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", + "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", + "dev": true, + "dependencies": { + "array-union": "^2.1.0", + "dir-glob": "^3.0.1", + "fast-glob": "^3.2.9", + "ignore": "^5.2.0", + "merge2": "^1.4.1", + "slash": "^3.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "peer": true + }, + "node_modules/graphemer": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", + "dev": true + }, + "node_modules/hamt_plus": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/hamt_plus/-/hamt_plus-1.0.2.tgz", + "integrity": "sha512-t2JXKaehnMb9paaYA7J0BX8QQAY8lwfQ9Gjf4pg/mk4krt+cmwmU652HOoWonf+7+EQV97ARPMhhVgU1ra2GhA==" + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/hyphenate-style-name": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/hyphenate-style-name/-/hyphenate-style-name-1.0.4.tgz", + "integrity": "sha512-ygGZLjmXfPHj+ZWh6LwbC37l43MhfztxetbFCoYTM2VjkIUpeHgSNn7QIyVFj7YQ1Wl9Cbw5sholVJPzWvC2MQ==" + }, + "node_modules/ignore": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.0.tgz", + "integrity": "sha512-g7dmpshy+gD7mh88OC9NwSGTKoc3kyLAZQRU1mt53Aw/vnvfXnbC+F/7F7QoYVKbV+KNvJx8wArewKy1vXMtlg==", + "dev": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/immer": { + "version": "10.0.3", + "resolved": "https://registry.npmjs.org/immer/-/immer-10.0.3.tgz", + "integrity": "sha512-pwupu3eWfouuaowscykeckFmVTpqbzW+rXFCX8rQLkZzM9ftBmU/++Ra+o+L27mz03zJTlyV4UUr+fdKNffo4A==", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/immer" + } + }, + "node_modules/import-fresh": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", + "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", + "dev": true, + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "dev": true, + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "dependencies": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==" + }, + "node_modules/inline-style-prefixer": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/inline-style-prefixer/-/inline-style-prefixer-7.0.0.tgz", + "integrity": "sha512-I7GEdScunP1dQ6IM2mQWh6v0mOYdYmH3Bp31UecKdrcUgcURTcctSe1IECdUznSHKSmsHtjrT3CwCPI1pyxfUQ==", + "dependencies": { + "css-in-js-utils": "^3.1.0", + "fast-loops": "^1.1.3" + } + }, + "node_modules/inter-ui": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/inter-ui/-/inter-ui-4.0.0.tgz", + "integrity": "sha512-/2XKDS/6DJYZzBHRsFVrZ2l0NYhaRZXki3jUnpJu4DAZtsZdDUtOHfG5msc6ifwKg59eJiyqeIYBDNgvxz+7XQ==", + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/invariant": { + "version": "2.2.4", + "resolved": "https://registry.npmjs.org/invariant/-/invariant-2.2.4.tgz", + "integrity": "sha512-phJfQVBuaJM5raOpJjSfkiD6BpbCE4Ns//LaXl6wGYtUBY83nWS6Rf9tXm2e8VaK60JEjYldbPif/A2B1C2gNA==", + "dependencies": { + "loose-envify": "^1.0.0" + } + }, + "node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-core-module": { + "version": "2.13.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.1.tgz", + "integrity": "sha512-hHrIjvZsftOsvKSn2TRYl63zvxsgE0K+0mYMoH6gD4omR5IWB2KynivBQczo3+wF1cCkjzvptnI9Q0sPU66ilw==", + "dependencies": { + "hasown": "^2.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-path-inside": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-3.0.3.tgz", + "integrity": "sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true + }, + "node_modules/jiti": { + "version": "1.21.0", + "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.0.tgz", + "integrity": "sha512-gFqAIbuKyyso/3G2qhiO2OM6shY6EPP/R0+mkDbyspxKazh8BXDC5FiFsUjlczgdNz/vfra0da2y+aHrusLG/Q==", + "bin": { + "jiti": "bin/jiti.js" + } + }, + "node_modules/js-cookie": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/js-cookie/-/js-cookie-2.2.1.tgz", + "integrity": "sha512-HvdH2LzI/EAZcUwA8+0nKNtWHqS+ZmijLA30RwZA0bo7ToCckjK5MkGhjED9KoRcXO6BaGI3I9UIzSA1FKFPOQ==" + }, + "node_modules/js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==" + }, + "node_modules/js-yaml": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "dev": true, + "dependencies": { + "argparse": "^2.0.1" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + }, + "node_modules/jsesc": { + "version": "2.5.2", + "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-2.5.2.tgz", + "integrity": "sha512-OYu7XEzjkCQ3C5Ps3QIZsQfNpqoJyZZA99wd9aWd05NCtC5pWOkShK2mkL6HXQR6/Cy2lbNdPlZBpuQHXE63gA==", + "dev": true, + "bin": { + "jsesc": "bin/jsesc" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/json-buffer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", + "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", + "dev": true + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true + }, + "node_modules/json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", + "dev": true + }, + "node_modules/json5": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", + "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", + "dev": true, + "bin": { + "json5": "lib/cli.js" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/keyv": { + "version": "4.5.4", + "resolved": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz", + "integrity": "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==", + "dev": true, + "dependencies": { + "json-buffer": "3.0.1" + } + }, + "node_modules/levn": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "dev": true, + "dependencies": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/lilconfig": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-2.1.0.tgz", + "integrity": "sha512-utWOt/GHzuUxnLKxB6dk81RoOeoNeHgbrXiuGk4yyF5qlRz+iIVWu56E2fqGHFrXz0QNUhLB/8nKqvRH66JKGQ==", + "engines": { + "node": ">=10" + } + }, + "node_modules/lines-and-columns": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", + "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==" + }, + "node_modules/locate-path": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", + "dev": true, + "dependencies": { + "p-locate": "^5.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/lodash": { + "version": "4.17.21", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==" + }, + "node_modules/lodash.merge": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "dev": true + }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": "cli.js" + } + }, + "node_modules/lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "dev": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/lucide-react": { + "version": "0.292.0", + "resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.292.0.tgz", + "integrity": "sha512-rRgUkpEHWpa5VCT66YscInCQmQuPCB1RFRzkkxMxg4b+jaL0V12E3riWWR2Sh5OIiUhCwGW/ZExuEO4Az32E6Q==", + "peerDependencies": { + "react": "^16.5.1 || ^17.0.0 || ^18.0.0" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/mdn-data": { + "version": "2.0.14", + "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.0.14.tgz", + "integrity": "sha512-dn6wd0uw5GsdswPFfsgMp5NSB0/aDe6fK94YJV/AJDYXL6HVLWBsxeq7js7Ad+mU2K9LAlwpk6kN2D5mwCPVow==" + }, + "node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "engines": { + "node": ">= 8" + } + }, + "node_modules/micromatch": { + "version": "4.0.5", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", + "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "dependencies": { + "braces": "^3.0.2", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "license": "MIT", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/mitt": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mitt/-/mitt-3.0.1.tgz", + "integrity": "sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==" + }, + "node_modules/ms": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" + }, + "node_modules/mz": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/mz/-/mz-2.7.0.tgz", + "integrity": "sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==", + "dependencies": { + "any-promise": "^1.0.0", + "object-assign": "^4.0.1", + "thenify-all": "^1.0.0" + } + }, + "node_modules/nano-css": { + "version": "5.4.0", + "resolved": "https://registry.npmjs.org/nano-css/-/nano-css-5.4.0.tgz", + "integrity": "sha512-QIbVsMMMsC+RQKJPxFDM70kf31A/JxNLE0D9tX9nwq4tcigY/vpvOJKphcQo55/RbriTnFSgrGnFhb8Y/6hs5g==", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.4.15", + "css-tree": "^1.1.2", + "csstype": "^3.1.2", + "fastest-stable-stringify": "^2.0.2", + "inline-style-prefixer": "^7.0.0", + "rtl-css-js": "^1.16.1", + "stacktrace-js": "^2.0.2", + "stylis": "^4.3.0" + }, + "peerDependencies": { + "react": "*", + "react-dom": "*" + } + }, + "node_modules/nanoid": { + "version": "3.3.7", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz", + "integrity": "sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", + "dev": true + }, + "node_modules/next": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/next/-/next-14.0.3.tgz", + "integrity": "sha512-AbYdRNfImBr3XGtvnwOxq8ekVCwbFTv/UJoLwmaX89nk9i051AEY4/HAWzU0YpaTDw8IofUpmuIlvzWF13jxIw==", + "peer": true, + "dependencies": { + "@next/env": "14.0.3", + "@swc/helpers": "0.5.2", + "busboy": "1.6.0", + "caniuse-lite": "^1.0.30001406", + "postcss": "8.4.31", + "styled-jsx": "5.1.1", + "watchpack": "2.4.0" + }, + "bin": { + "next": "dist/bin/next" + }, + "engines": { + "node": ">=18.17.0" + }, + "optionalDependencies": { + "@next/swc-darwin-arm64": "14.0.3", + "@next/swc-darwin-x64": "14.0.3", + "@next/swc-linux-arm64-gnu": "14.0.3", + "@next/swc-linux-arm64-musl": "14.0.3", + "@next/swc-linux-x64-gnu": "14.0.3", + "@next/swc-linux-x64-musl": "14.0.3", + "@next/swc-win32-arm64-msvc": "14.0.3", + "@next/swc-win32-ia32-msvc": "14.0.3", + "@next/swc-win32-x64-msvc": "14.0.3" + }, + "peerDependencies": { + "@opentelemetry/api": "^1.1.0", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "sass": "^1.3.0" + }, + "peerDependenciesMeta": { + "@opentelemetry/api": { + "optional": true + }, + "sass": { + "optional": true + } + } + }, + "node_modules/next-themes": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/next-themes/-/next-themes-0.2.1.tgz", + "integrity": "sha512-B+AKNfYNIzh0vqQQKqQItTS8evEouKD7H5Hj3kmuPERwddR2TxvDSFZuTj6T7Jfn1oyeUyJMydPl1Bkxkh0W7A==", + "peerDependencies": { + "next": "*", + "react": "*", + "react-dom": "*" + } + }, + "node_modules/node-releases": { + "version": "2.0.13", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz", + "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==", + "dev": true + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/normalize-range": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/normalize-range/-/normalize-range-0.1.2.tgz", + "integrity": "sha512-bdok/XvKII3nUpklnV6P2hxtMNrCboOjAcyBuQnWEhO665FwrSNRxU+AqpsyvO6LgGYPspN+lu5CLtw4jPRKNA==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-hash": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-3.0.0.tgz", + "integrity": "sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw==", + "engines": { + "node": ">= 6" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/optionator": { + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.3.tgz", + "integrity": "sha512-JjCoypp+jKn1ttEFExxhetCKeJt9zhAgAve5FXHixTvFDW/5aEktX9bufBKLRRMdU7bNtpLfcGu94B3cdEJgjg==", + "dev": true, + "dependencies": { + "@aashutoshrathi/word-wrap": "^1.2.3", + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "dev": true, + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", + "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", + "dev": true, + "dependencies": { + "p-limit": "^3.0.2" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dev": true, + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==" + }, + "node_modules/path-type": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", + "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/picocolors": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", + "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==" + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pify": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", + "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pirates": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.6.tgz", + "integrity": "sha512-saLsH7WeYYPiD25LDuLRRY/i+6HaPYr6G1OUlN39otzkSTxKnubR9RTxS3/Kk50s1g2JTgFwWQDQyplC5/SHZg==", + "engines": { + "node": ">= 6" + } + }, + "node_modules/postcss": { + "version": "8.4.31", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.31.tgz", + "integrity": "sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "nanoid": "^3.3.6", + "picocolors": "^1.0.0", + "source-map-js": "^1.0.2" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/postcss-import": { + "version": "15.1.0", + "resolved": "https://registry.npmjs.org/postcss-import/-/postcss-import-15.1.0.tgz", + "integrity": "sha512-hpr+J05B2FVYUAXHeK1YyI267J/dDDhMU6B6civm8hSY1jYJnBXxzKDKDswzJmtLHryrjhnDjqqp/49t8FALew==", + "dependencies": { + "postcss-value-parser": "^4.0.0", + "read-cache": "^1.0.0", + "resolve": "^1.1.7" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "postcss": "^8.0.0" + } + }, + "node_modules/postcss-js": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/postcss-js/-/postcss-js-4.0.1.tgz", + "integrity": "sha512-dDLF8pEO191hJMtlHFPRa8xsizHaM82MLfNkUHdUtVEV3tgTp5oj+8qbEqYM57SLfc74KSbw//4SeJma2LRVIw==", + "dependencies": { + "camelcase-css": "^2.0.1" + }, + "engines": { + "node": "^12 || ^14 || >= 16" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + "peerDependencies": { + "postcss": "^8.4.21" + } + }, + "node_modules/postcss-load-config": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/postcss-load-config/-/postcss-load-config-4.0.2.tgz", + "integrity": "sha512-bSVhyJGL00wMVoPUzAVAnbEoWyqRxkjv64tUl427SKnPrENtq6hJwUojroMz2VB+Q1edmi4IfrAPpami5VVgMQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "lilconfig": "^3.0.0", + "yaml": "^2.3.4" + }, + "engines": { + "node": ">= 14" + }, + "peerDependencies": { + "postcss": ">=8.0.9", + "ts-node": ">=9.0.0" + }, + "peerDependenciesMeta": { + "postcss": { + "optional": true + }, + "ts-node": { + "optional": true + } + } + }, + "node_modules/postcss-load-config/node_modules/lilconfig": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.0.0.tgz", + "integrity": "sha512-K2U4W2Ff5ibV7j7ydLr+zLAkIg5JJ4lPn1Ltsdt+Tz/IjQ8buJ55pZAxoP34lqIiwtF9iAvtLv3JGv7CAyAg+g==", + "engines": { + "node": ">=14" + } + }, + "node_modules/postcss-nested": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/postcss-nested/-/postcss-nested-6.0.1.tgz", + "integrity": "sha512-mEp4xPMi5bSWiMbsgoPfcP74lsWLHkQbZc3sY+jWYd65CUwXrUaTp0fmNpa01ZcETKlIgUdFN/MpS2xZtqL9dQ==", + "dependencies": { + "postcss-selector-parser": "^6.0.11" + }, + "engines": { + "node": ">=12.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + "peerDependencies": { + "postcss": "^8.2.14" + } + }, + "node_modules/postcss-selector-parser": { + "version": "6.0.13", + "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.13.tgz", + "integrity": "sha512-EaV1Gl4mUEV4ddhDnv/xtj7sxwrwxdetHdWUGnT4VJQf+4d05v6lHYZr8N573k5Z0BViss7BDhfWtKS3+sfAqQ==", + "dependencies": { + "cssesc": "^3.0.0", + "util-deprecate": "^1.0.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/postcss-value-parser": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", + "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==" + }, + "node_modules/prelude-ls": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "dev": true, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==" + }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] + }, + "node_modules/react": { + "version": "18.2.0", + "resolved": "https://registry.npmjs.org/react/-/react-18.2.0.tgz", + "integrity": "sha512-/3IjMdb2L9QbBdWiW5e3P2/npwMBaU9mHCSCUzNln0ZCYbcfTsGbTJrU/kGemdH2IWmB2ioZ+zkxtmq6g09fGQ==", + "dependencies": { + "loose-envify": "^1.1.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-dom": { + "version": "18.2.0", + "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.2.0.tgz", + "integrity": "sha512-6IMTriUmvsjHUjNtEDudZfuDQUoWXVxKHhlEGSk81n4YFS+r/Kl99wXiwlVXtPBtJenozv2P+hxDsw9eA7Xo6g==", + "dependencies": { + "loose-envify": "^1.1.0", + "scheduler": "^0.23.0" + }, + "peerDependencies": { + "react": "^18.2.0" + } + }, + "node_modules/react-hook-form": { + "version": "7.48.2", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.48.2.tgz", + "integrity": "sha512-H0T2InFQb1hX7qKtDIZmvpU1Xfn/bdahWBN1fH19gSe4bBEqTfmlr7H3XWTaVtiK4/tpPaI1F3355GPMZYge+A==", + "engines": { + "node": ">=12.22.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/react-hook-form" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17 || ^18" + } + }, + "node_modules/react-hotkeys-hook": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/react-hotkeys-hook/-/react-hotkeys-hook-4.4.1.tgz", + "integrity": "sha512-sClBMBioFEgFGYLTWWRKvhxcCx1DRznd+wkFHwQZspnRBkHTgruKIHptlK/U/2DPX8BhHoRGzpMVWUXMmdZlmw==", + "peerDependencies": { + "react": ">=16.8.1", + "react-dom": ">=16.8.1" + } + }, + "node_modules/react-photo-album": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/react-photo-album/-/react-photo-album-2.3.0.tgz", + "integrity": "sha512-CU+UMK4ZQHIoPZ672TSst9loKE5bxy6w0+bf7bY4XOw1g1C7+VdDWCW+wD8wPpbg2ve38QBTS73HVe6xYLAQ3w==", + "engines": { + "node": ">=12" + }, + "peerDependencies": { + "react": ">=16.8.0" + } + }, + "node_modules/react-refresh": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.14.0.tgz", + "integrity": "sha512-wViHqhAd8OHeLS/IRMJjTSDHF3U9eWi62F/MledQGPdJGDhodXJ9PBLNGr6WWL7qlH12Mt3TyTpbS+hGXMjCzQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-remove-scroll": { + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.5.5.tgz", + "integrity": "sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==", + "dependencies": { + "react-remove-scroll-bar": "^2.3.3", + "react-style-singleton": "^2.2.1", + "tslib": "^2.1.0", + "use-callback-ref": "^1.3.0", + "use-sidecar": "^1.1.2" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/react-remove-scroll-bar": { + "version": "2.3.4", + "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.4.tgz", + "integrity": "sha512-63C4YQBUt0m6ALadE9XV56hV8BgJWDmmTPY758iIJjfQKt2nYwoUrPk0LXRXcB/yIj82T1/Ixfdpdk68LwIB0A==", + "dependencies": { + "react-style-singleton": "^2.2.1", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/react-style-singleton": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.1.tgz", + "integrity": "sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==", + "dependencies": { + "get-nonce": "^1.0.0", + "invariant": "^2.2.4", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/react-universal-interface": { + "version": "0.6.2", + "resolved": "https://registry.npmjs.org/react-universal-interface/-/react-universal-interface-0.6.2.tgz", + "integrity": "sha512-dg8yXdcQmvgR13RIlZbTRQOoUrDciFVoSBZILwjE2LFISxZZ8loVJKAkuzswl5js8BHda79bIb2b84ehU8IjXw==", + "peerDependencies": { + "react": "*", + "tslib": "*" + } + }, + "node_modules/react-use": { + "version": "17.4.0", + "resolved": "https://registry.npmjs.org/react-use/-/react-use-17.4.0.tgz", + "integrity": "sha512-TgbNTCA33Wl7xzIJegn1HndB4qTS9u03QUwyNycUnXaweZkE4Kq2SB+Yoxx8qbshkZGYBDvUXbXWRUmQDcZZ/Q==", + "dependencies": { + "@types/js-cookie": "^2.2.6", + "@xobotyi/scrollbar-width": "^1.9.5", + "copy-to-clipboard": "^3.3.1", + "fast-deep-equal": "^3.1.3", + "fast-shallow-equal": "^1.0.0", + "js-cookie": "^2.2.1", + "nano-css": "^5.3.1", + "react-universal-interface": "^0.6.2", + "resize-observer-polyfill": "^1.5.1", + "screenfull": "^5.1.0", + "set-harmonic-interval": "^1.0.1", + "throttle-debounce": "^3.0.1", + "ts-easing": "^0.2.0", + "tslib": "^2.1.0" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0" + } + }, + "node_modules/react-zoom-pan-pinch": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/react-zoom-pan-pinch/-/react-zoom-pan-pinch-3.3.0.tgz", + "integrity": "sha512-vy1h8aenDzXye+HRqANZaSA8IPHoqOiuDPFBkswoyPUH8uMfsmbeH6gFI4r4BhEJa0xIlcA+FbvhidRWKGUrOg==", + "engines": { + "node": ">=8", + "npm": ">=5" + }, + "peerDependencies": { + "react": "*", + "react-dom": "*" + } + }, + "node_modules/read-cache": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", + "integrity": "sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==", + "dependencies": { + "pify": "^2.3.0" + } + }, + "node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/recoil": { + "version": "0.7.7", + "resolved": "https://registry.npmjs.org/recoil/-/recoil-0.7.7.tgz", + "integrity": "sha512-8Og5KPQW9LwC577Vc7Ug2P0vQshkv1y3zG3tSSkWMqkWSwHmE+by06L8JtnGocjW6gcCvfwB3YtrJG6/tWivNQ==", + "dependencies": { + "hamt_plus": "1.0.2" + }, + "peerDependencies": { + "react": ">=16.13.1" + }, + "peerDependenciesMeta": { + "react-dom": { + "optional": true + }, + "react-native": { + "optional": true + } + } + }, + "node_modules/regenerator-runtime": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.0.tgz", + "integrity": "sha512-srw17NI0TUWHuGa5CFGGmhfNIeja30WMBfbslPNhf6JrqQlLN5gcrvig1oqPxiVaXb0oW0XRKtH6Nngs5lKCIA==" + }, + "node_modules/resize-observer-polyfill": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/resize-observer-polyfill/-/resize-observer-polyfill-1.5.1.tgz", + "integrity": "sha512-LwZrotdHOo12nQuZlHEmtuXdqGoOD0OhaxopaNFxWzInpEgaLWoVuAMbTzixuosCx2nEG58ngzW3vxdWoxIgdg==" + }, + "node_modules/resolve": { + "version": "1.22.8", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.8.tgz", + "integrity": "sha512-oKWePCxqpd6FlLvGV1VU0x7bkPmmCNolxzjMf4NczoDnQcIWrAF+cPtZn5i6n+RfD2d9i0tzpKnG6Yk168yIyw==", + "dependencies": { + "is-core-module": "^2.13.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/rimraf": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "dev": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rollup": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.5.0.tgz", + "integrity": "sha512-41xsWhzxqjMDASCxH5ibw1mXk+3c4TNI2UjKbLxe6iEzrSQnqOzmmK8/3mufCPbzHNJ2e04Fc1ddI35hHy+8zg==", + "dev": true, + "bin": { + "rollup": "dist/bin/rollup" + }, + "engines": { + "node": ">=18.0.0", + "npm": ">=8.0.0" + }, + "optionalDependencies": { + "@rollup/rollup-android-arm-eabi": "4.5.0", + "@rollup/rollup-android-arm64": "4.5.0", + "@rollup/rollup-darwin-arm64": "4.5.0", + "@rollup/rollup-darwin-x64": "4.5.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.5.0", + "@rollup/rollup-linux-arm64-gnu": "4.5.0", + "@rollup/rollup-linux-arm64-musl": "4.5.0", + "@rollup/rollup-linux-x64-gnu": "4.5.0", + "@rollup/rollup-linux-x64-musl": "4.5.0", + "@rollup/rollup-win32-arm64-msvc": "4.5.0", + "@rollup/rollup-win32-ia32-msvc": "4.5.0", + "@rollup/rollup-win32-x64-msvc": "4.5.0", + "fsevents": "~2.3.2" + } + }, + "node_modules/rtl-css-js": { + "version": "1.16.1", + "resolved": "https://registry.npmjs.org/rtl-css-js/-/rtl-css-js-1.16.1.tgz", + "integrity": "sha512-lRQgou1mu19e+Ya0LsTvKrVJ5TYUbqCVPAiImX3UfLTenarvPUl1QFdvu5Z3PYmHT9RCcwIfbjRQBntExyj3Zg==", + "dependencies": { + "@babel/runtime": "^7.1.2" + } + }, + "node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/scheduler": { + "version": "0.23.0", + "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.23.0.tgz", + "integrity": "sha512-CtuThmgHNg7zIZWAXi3AsyIzA3n4xx7aNyjwC2VJldO2LMVDhFK+63xGqq6CsJH4rTAt6/M+N4GhZiDYPx9eUw==", + "dependencies": { + "loose-envify": "^1.1.0" + } + }, + "node_modules/screenfull": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/screenfull/-/screenfull-5.2.0.tgz", + "integrity": "sha512-9BakfsO2aUQN2K9Fdbj87RJIEZ82Q9IGim7FqM5OsebfoFC6ZHXgDq/KvniuLTPdeM8wY2o6Dj3WQ7KeQCj3cA==", + "engines": { + "node": ">=0.10.0" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/semver": { + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", + "dev": true, + "dependencies": { + "lru-cache": "^6.0.0" + }, + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/set-harmonic-interval": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/set-harmonic-interval/-/set-harmonic-interval-1.0.1.tgz", + "integrity": "sha512-AhICkFV84tBP1aWqPwLZqFvAwqEoVA9kxNMniGEUvzOlm4vLmOFLiTT3UZ6bziJTy4bOVpzWGTfSCbmaayGx8g==", + "engines": { + "node": ">=6.9" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/slash": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", + "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/socket.io-client": { + "version": "4.7.2", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.7.2.tgz", + "integrity": "sha512-vtA0uD4ibrYD793SOIAwlo8cj6haOeMHrGvwPxJsxH7CeIksqJ+3Zc06RvWTIFgiSqx4A3sOnTXpfAEE2Zyz6w==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.2", + "engine.io-client": "~6.5.2", + "socket.io-parser": "~4.2.4" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-parser": { + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-js": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", + "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/stack-generator": { + "version": "2.0.10", + "resolved": "https://registry.npmjs.org/stack-generator/-/stack-generator-2.0.10.tgz", + "integrity": "sha512-mwnua/hkqM6pF4k8SnmZ2zfETsRUpWXREfA/goT8SLCV4iOFa4bzOX2nDipWAZFPTjLvQB82f5yaodMVhK0yJQ==", + "dependencies": { + "stackframe": "^1.3.4" + } + }, + "node_modules/stackframe": { + "version": "1.3.4", + "resolved": "https://registry.npmjs.org/stackframe/-/stackframe-1.3.4.tgz", + "integrity": "sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==" + }, + "node_modules/stacktrace-gps": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/stacktrace-gps/-/stacktrace-gps-3.1.2.tgz", + "integrity": "sha512-GcUgbO4Jsqqg6RxfyTHFiPxdPqF+3LFmQhm7MgCuYQOYuWyqxo5pwRPz5d/u6/WYJdEnWfK4r+jGbyD8TSggXQ==", + "dependencies": { + "source-map": "0.5.6", + "stackframe": "^1.3.4" + } + }, + "node_modules/stacktrace-gps/node_modules/source-map": { + "version": "0.5.6", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.5.6.tgz", + "integrity": "sha512-MjZkVp0NHr5+TPihLcadqnlVoGIoWo4IBHptutGh9wI3ttUYvCG26HkSuDi+K6lsZ25syXJXcctwgyVCt//xqA==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/stacktrace-js": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/stacktrace-js/-/stacktrace-js-2.0.2.tgz", + "integrity": "sha512-Je5vBeY4S1r/RnLydLl0TBTi3F2qdfWmYsGvtfZgEI+SCprPppaIhQf5nGcal4gI4cGpCV/duLcAzT1np6sQqg==", + "dependencies": { + "error-stack-parser": "^2.0.6", + "stack-generator": "^2.0.5", + "stacktrace-gps": "^3.0.4" + } + }, + "node_modules/streamsearch": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/streamsearch/-/streamsearch-1.1.0.tgz", + "integrity": "sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==", + "peer": true, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-json-comments": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "dev": true, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/styled-jsx": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/styled-jsx/-/styled-jsx-5.1.1.tgz", + "integrity": "sha512-pW7uC1l4mBZ8ugbiZrcIsiIvVx1UmTfw7UkC3Um2tmfUq9Bhk8IiyEIPl6F8agHgjzku6j0xQEZbfA5uSgSaCw==", + "peer": true, + "dependencies": { + "client-only": "0.0.1" + }, + "engines": { + "node": ">= 12.0.0" + }, + "peerDependencies": { + "react": ">= 16.8.0 || 17.x.x || ^18.0.0-0" + }, + "peerDependenciesMeta": { + "@babel/core": { + "optional": true + }, + "babel-plugin-macros": { + "optional": true + } + } + }, + "node_modules/stylis": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.0.tgz", + "integrity": "sha512-E87pIogpwUsUwXw7dNyU4QDjdgVMy52m+XEOPEKUn161cCzWjjhPSQhByfd1CcNvrOLnXQ6OnnZDwnJrz/Z4YQ==" + }, + "node_modules/sucrase": { + "version": "3.34.0", + "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.34.0.tgz", + "integrity": "sha512-70/LQEZ07TEcxiU2dz51FKaE6hCTWC6vr7FOk3Gr0U60C3shtAN+H+BFr9XlYe5xqf3RA8nrc+VIwzCfnxuXJw==", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.2", + "commander": "^4.0.0", + "glob": "7.1.6", + "lines-and-columns": "^1.1.6", + "mz": "^2.7.0", + "pirates": "^4.0.1", + "ts-interface-checker": "^0.1.9" + }, + "bin": { + "sucrase": "bin/sucrase", + "sucrase-node": "bin/sucrase-node" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/glob": { + "version": "7.1.6", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.1.6.tgz", + "integrity": "sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA==", + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.0.4", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/tailwind-merge": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-2.0.0.tgz", + "integrity": "sha512-WO8qghn9yhsldLSg80au+3/gY9E4hFxIvQ3qOmlpXnqpDKoMruKfi/56BbbMg6fHTQJ9QD3cc79PoWqlaQE4rw==", + "dependencies": { + "@babel/runtime": "^7.23.1" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/dcastil" + } + }, + "node_modules/tailwindcss": { + "version": "3.3.5", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.3.5.tgz", + "integrity": "sha512-5SEZU4J7pxZgSkv7FP1zY8i2TIAOooNZ1e/OGtxIEv6GltpoiXUqWvLy89+a10qYTB1N5Ifkuw9lqQkN9sscvA==", + "dependencies": { + "@alloc/quick-lru": "^5.2.0", + "arg": "^5.0.2", + "chokidar": "^3.5.3", + "didyoumean": "^1.2.2", + "dlv": "^1.1.3", + "fast-glob": "^3.3.0", + "glob-parent": "^6.0.2", + "is-glob": "^4.0.3", + "jiti": "^1.19.1", + "lilconfig": "^2.1.0", + "micromatch": "^4.0.5", + "normalize-path": "^3.0.0", + "object-hash": "^3.0.0", + "picocolors": "^1.0.0", + "postcss": "^8.4.23", + "postcss-import": "^15.1.0", + "postcss-js": "^4.0.1", + "postcss-load-config": "^4.0.1", + "postcss-nested": "^6.0.1", + "postcss-selector-parser": "^6.0.11", + "resolve": "^1.22.2", + "sucrase": "^3.32.0" + }, + "bin": { + "tailwind": "lib/cli.js", + "tailwindcss": "lib/cli.js" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tailwindcss-animate": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/tailwindcss-animate/-/tailwindcss-animate-1.0.7.tgz", + "integrity": "sha512-bl6mpH3T7I3UFxuvDEXLxy/VuFxBk5bbzplh7tXI68mwMokNYd1t9qPBHlnyTwfa4JGC4zP516I1hYYtQ/vspA==", + "peerDependencies": { + "tailwindcss": ">=3.0.0 || insiders" + } + }, + "node_modules/text-table": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", + "integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==", + "dev": true + }, + "node_modules/thenify": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/thenify/-/thenify-3.3.1.tgz", + "integrity": "sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==", + "dependencies": { + "any-promise": "^1.0.0" + } + }, + "node_modules/thenify-all": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/thenify-all/-/thenify-all-1.6.0.tgz", + "integrity": "sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==", + "dependencies": { + "thenify": ">= 3.1.0 < 4" + }, + "engines": { + "node": ">=0.8" + } + }, + "node_modules/throttle-debounce": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/throttle-debounce/-/throttle-debounce-3.0.1.tgz", + "integrity": "sha512-dTEWWNu6JmeVXY0ZYoPuH5cRIwc0MeGbJwah9KUNYSJwommQpCzTySTpEe8Gs1J23aeWEuAobe4Ag7EHVt/LOg==", + "engines": { + "node": ">=10" + } + }, + "node_modules/to-fast-properties": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/to-fast-properties/-/to-fast-properties-2.0.0.tgz", + "integrity": "sha512-/OaKK0xYrs3DmxRYqL/yDc+FxFUVYhDlXMhRmv3z915w2HF1tnN1omB354j8VUGO/hbRzyD6Y3sA7v7GS/ceog==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/toggle-selection": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/toggle-selection/-/toggle-selection-1.0.6.tgz", + "integrity": "sha512-BiZS+C1OS8g/q2RRbJmy59xpyghNBqrr6k5L/uKBGRsTfxmu3ffiRnd8mlGPUVayg8pvfi5urfnu8TU7DVOkLQ==" + }, + "node_modules/ts-api-utils": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.0.3.tgz", + "integrity": "sha512-wNMeqtMz5NtwpT/UZGY5alT+VoKdSsOOP/kqHFcUW1P/VRhH2wJ48+DN2WwUliNbQ976ETwDL0Ifd2VVvgonvg==", + "dev": true, + "engines": { + "node": ">=16.13.0" + }, + "peerDependencies": { + "typescript": ">=4.2.0" + } + }, + "node_modules/ts-easing": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/ts-easing/-/ts-easing-0.2.0.tgz", + "integrity": "sha512-Z86EW+fFFh/IFB1fqQ3/+7Zpf9t2ebOAxNI/V6Wo7r5gqiqtxmgTlQ1qbqQcjLKYeSHPTsEmvlJUDg/EuL0uHQ==" + }, + "node_modules/ts-interface-checker": { + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", + "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==" + }, + "node_modules/tslib": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" + }, + "node_modules/tsutils": { + "version": "3.21.0", + "resolved": "https://registry.npmjs.org/tsutils/-/tsutils-3.21.0.tgz", + "integrity": "sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==", + "dev": true, + "dependencies": { + "tslib": "^1.8.1" + }, + "engines": { + "node": ">= 6" + }, + "peerDependencies": { + "typescript": ">=2.8.0 || >= 3.2.0-dev || >= 3.3.0-dev || >= 3.4.0-dev || >= 3.5.0-dev || >= 3.6.0-dev || >= 3.6.0-beta || >= 3.7.0-dev || >= 3.7.0-beta" + } + }, + "node_modules/tsutils/node_modules/tslib": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", + "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", + "dev": true + }, + "node_modules/type-check": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "dev": true, + "dependencies": { + "prelude-ls": "^1.2.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/type-fest": { + "version": "0.20.2", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", + "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/typescript": { + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", + "dev": true, + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "dev": true + }, + "node_modules/update-browserslist-db": { + "version": "1.0.13", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", + "integrity": "sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "escalade": "^3.1.1", + "picocolors": "^1.0.0" + }, + "bin": { + "update-browserslist-db": "cli.js" + }, + "peerDependencies": { + "browserslist": ">= 4.21.0" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/use-callback-ref": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.0.tgz", + "integrity": "sha512-3FT9PRuRdbB9HfXhEq35u4oZkvpJ5kuYbpqhCfmiZyReuRgpnhDlbr2ZEnnuS0RrJAPn6l23xjFg9kpDM+Ms7w==", + "dependencies": { + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/use-sidecar": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.2.tgz", + "integrity": "sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==", + "dependencies": { + "detect-node-es": "^1.1.0", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.9.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/use-sync-external-store": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.2.0.tgz", + "integrity": "sha512-eEgnFxGQ1Ife9bzYs6VLi8/4X6CObHMw9Qr9tPY43iKwsPw8xE8+EFsf/2cFZ5S3esXgpWgtSCtLNS41F+sKPA==", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + } + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==" + }, + "node_modules/vite": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.0.0.tgz", + "integrity": "sha512-ESJVM59mdyGpsiNAeHQOR/0fqNoOyWPYesFto8FFZugfmhdHx8Fzd8sF3Q/xkVhZsyOxHfdM7ieiVAorI9RjFw==", + "dev": true, + "dependencies": { + "esbuild": "^0.19.3", + "postcss": "^8.4.31", + "rollup": "^4.2.0" + }, + "bin": { + "vite": "bin/vite.js" + }, + "engines": { + "node": "^18.0.0 || >=20.0.0" + }, + "funding": { + "url": "https://github.com/vitejs/vite?sponsor=1" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + }, + "peerDependencies": { + "@types/node": "^18.0.0 || >=20.0.0", + "less": "*", + "lightningcss": "^1.21.0", + "sass": "*", + "stylus": "*", + "sugarss": "*", + "terser": "^5.4.0" + }, + "peerDependenciesMeta": { + "@types/node": { + "optional": true + }, + "less": { + "optional": true + }, + "lightningcss": { + "optional": true + }, + "sass": { + "optional": true + }, + "stylus": { + "optional": true + }, + "sugarss": { + "optional": true + }, + "terser": { + "optional": true + } + } + }, + "node_modules/watchpack": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz", + "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==", + "peer": true, + "dependencies": { + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.1.2" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==" + }, + "node_modules/ws": { + "version": "8.11.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", + "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": "^5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xmlhttprequest-ssl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz", + "integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "dev": true + }, + "node_modules/yaml": { + "version": "2.3.4", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.3.4.tgz", + "integrity": "sha512-8aAvwVUSHpfEqTQ4w/KMlf3HcRdt50E5ODIQJBw1fQ5RL34xabzxtUlzTXVqc4rkZsPbvrXKWnABCD7kWSmocA==", + "engines": { + "node": ">= 14" + } + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/zod": { + "version": "3.22.4", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.22.4.tgz", + "integrity": "sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/zundo": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/zundo/-/zundo-2.0.0.tgz", + "integrity": "sha512-XzKDyunmyxvQHKDjgTmOClOQscJAm5NAa1iEazR0DilvV/uwCjnDwlHJuJ+GmG/oj5RMjzsD0ptghZzjEj1w4g==", + "funding": { + "type": "individual", + "url": "https://github.com/sponsors/charkour" + }, + "peerDependencies": { + "zustand": "^4.3.0" + } + }, + "node_modules/zustand": { + "version": "4.4.6", + "resolved": "https://registry.npmjs.org/zustand/-/zustand-4.4.6.tgz", + "integrity": "sha512-Rb16eW55gqL4W2XZpJh0fnrATxYEG3Apl2gfHTyDSE965x/zxslTikpNch0JgNjJA9zK6gEFW8Fl6d1rTZaqgg==", + "dependencies": { + "use-sync-external-store": "1.2.0" + }, + "engines": { + "node": ">=12.7.0" + }, + "peerDependencies": { + "@types/react": ">=16.8", + "immer": ">=9.0", + "react": ">=16.8" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "immer": { + "optional": true + }, + "react": { + "optional": true + } + } + } + } +} diff --git a/custom-demo/back-end/web_app/package.json b/custom-demo/back-end/web_app/package.json new file mode 100644 index 0000000..845a269 --- /dev/null +++ b/custom-demo/back-end/web_app/package.json @@ -0,0 +1,83 @@ +{ + "name": "web_app", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "tsc && vite build", + "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0", + "preview": "vite preview" + }, + "dependencies": { + "@heroicons/react": "^2.0.18", + "@hookform/resolvers": "^3.3.2", + "@radix-ui/react-accordion": "^1.1.2", + "@radix-ui/react-alert-dialog": "^1.0.5", + "@radix-ui/react-context-menu": "^2.1.5", + "@radix-ui/react-dialog": "^1.0.5", + "@radix-ui/react-dropdown-menu": "^2.0.6", + "@radix-ui/react-icons": "^1.3.0", + "@radix-ui/react-label": "^2.0.2", + "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-progress": "^1.0.3", + "@radix-ui/react-radio-group": "^1.1.3", + "@radix-ui/react-scroll-area": "^1.0.5", + "@radix-ui/react-select": "^2.0.0", + "@radix-ui/react-separator": "^1.0.3", + "@radix-ui/react-slider": "^1.1.2", + "@radix-ui/react-slot": "^1.0.2", + "@radix-ui/react-switch": "^1.0.3", + "@radix-ui/react-tabs": "^1.0.4", + "@radix-ui/react-toast": "^1.1.5", + "@radix-ui/react-toggle": "^1.0.3", + "@radix-ui/react-tooltip": "^1.0.7", + "@tanstack/react-query": "^5.8.7", + "@uidotdev/usehooks": "^2.4.1", + "axios": "^1.11.0", + "class-variance-authority": "^0.7.0", + "clsx": "^2.0.0", + "fuse.js": "^7.0.0", + "immer": "^10.0.3", + "inter-ui": "^4.0.0", + "lodash": "^4.17.21", + "lucide-react": "^0.292.0", + "mitt": "^3.0.1", + "next-themes": "^0.2.1", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-hook-form": "^7.48.2", + "react-hotkeys-hook": "^4.4.1", + "react-photo-album": "^2.3.0", + "react-use": "^17.4.0", + "react-zoom-pan-pinch": "^3.3.0", + "recoil": "^0.7.7", + "socket.io-client": "^4.7.2", + "tailwind-merge": "^2.0.0", + "tailwindcss-animate": "^1.0.7", + "zod": "^3.22.4", + "zundo": "^2.0.0", + "zustand": "^4.4.6" + }, + "devDependencies": { + "@tanstack/eslint-plugin-query": "^5.8.4", + "@types/axios": "^0.14.4", + "@types/flexsearch": "^0.7.6", + "@types/lodash": "^4.14.201", + "@types/node": "^20.9.2", + "@types/react": "^18.2.37", + "@types/react-dom": "^18.2.15", + "@typescript-eslint/eslint-plugin": "^6.10.0", + "@typescript-eslint/parser": "^6.10.0", + "@vitejs/plugin-react": "^4.2.0", + "@vitejs/plugin-react-swc": "^3.5.0", + "autoprefixer": "^10.4.16", + "eslint": "^8.53.0", + "eslint-plugin-react-hooks": "^4.6.0", + "eslint-plugin-react-refresh": "^0.4.4", + "postcss": "^8.4.31", + "tailwindcss": "^3.3.5", + "typescript": "^5.2.2", + "vite": "^5.0.0" + } +} diff --git a/custom-demo/back-end/web_app/postcss.config.js b/custom-demo/back-end/web_app/postcss.config.js new file mode 100644 index 0000000..2e7af2b --- /dev/null +++ b/custom-demo/back-end/web_app/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/custom-demo/back-end/web_app/src/App.tsx b/custom-demo/back-end/web_app/src/App.tsx new file mode 100644 index 0000000..6650840 --- /dev/null +++ b/custom-demo/back-end/web_app/src/App.tsx @@ -0,0 +1,167 @@ +import { useCallback, useEffect, useRef } from "react" + +import useInputImage from "@/hooks/useInputImage" +import { keepGUIAlive } from "@/lib/utils" +import { getServerConfig } from "@/lib/api" +import Header from "@/components/Header" +import Workspace from "@/components/Workspace" +import FileSelect from "@/components/FileSelect" +import { Toaster } from "./components/ui/toaster" +import { useStore } from "./lib/states" +import { useWindowSize } from "react-use" + +const SUPPORTED_FILE_TYPE = [ + "image/jpeg", + "image/png", + "image/webp", + "image/bmp", + "image/tiff", +] +function Home() { + const [file, updateAppState, setServerConfig, setFile] = useStore((state) => [ + state.file, + state.updateAppState, + state.setServerConfig, + state.setFile, + ]) + + const userInputImage = useInputImage() + + const windowSize = useWindowSize() + + useEffect(() => { + if (userInputImage) { + setFile(userInputImage) + } + }, [userInputImage, setFile]) + + useEffect(() => { + updateAppState({ windowSize }) + }, [windowSize]) + + useEffect(() => { + const fetchServerConfig = async () => { + const serverConfig = await getServerConfig() + setServerConfig(serverConfig) + if (serverConfig.isDesktop) { + // Keeping GUI Window Open + keepGUIAlive() + } + } + fetchServerConfig() + }, []) + + const dragCounter = useRef(0) + + const handleDrag = useCallback((event: any) => { + event.preventDefault() + event.stopPropagation() + }, []) + + const handleDragIn = useCallback((event: any) => { + event.preventDefault() + event.stopPropagation() + dragCounter.current += 1 + }, []) + + const handleDragOut = useCallback((event: any) => { + event.preventDefault() + event.stopPropagation() + dragCounter.current -= 1 + if (dragCounter.current > 0) return + }, []) + + const handleDrop = useCallback((event: any) => { + event.preventDefault() + event.stopPropagation() + if (event.dataTransfer.files && event.dataTransfer.files.length > 0) { + if (event.dataTransfer.files.length > 1) { + // setToastState({ + // open: true, + // desc: "Please drag and drop only one file", + // state: "error", + // duration: 3000, + // }) + } else { + const dragFile = event.dataTransfer.files[0] + const fileType = dragFile.type + if (SUPPORTED_FILE_TYPE.includes(fileType)) { + setFile(dragFile) + } else { + // setToastState({ + // open: true, + // desc: "Please drag and drop an image file", + // state: "error", + // duration: 3000, + // }) + } + } + event.dataTransfer.clearData() + } + }, []) + + const onPaste = useCallback((event: any) => { + // TODO: when sd side panel open, ctrl+v not work + // https://htmldom.dev/paste-an-image-from-the-clipboard/ + if (!event.clipboardData) { + return + } + const clipboardItems = event.clipboardData.items + const items: DataTransferItem[] = [].slice + .call(clipboardItems) + .filter((item: DataTransferItem) => { + // Filter the image items only + return item.type.indexOf("image") !== -1 + }) + + if (items.length === 0) { + return + } + + event.preventDefault() + event.stopPropagation() + + // TODO: add confirm dialog + + const item = items[0] + // Get the blob of image + const blob = item.getAsFile() + if (blob) { + setFile(blob) + } + }, []) + + useEffect(() => { + window.addEventListener("dragenter", handleDragIn) + window.addEventListener("dragleave", handleDragOut) + window.addEventListener("dragover", handleDrag) + window.addEventListener("drop", handleDrop) + window.addEventListener("paste", onPaste) + return function cleanUp() { + window.removeEventListener("dragenter", handleDragIn) + window.removeEventListener("dragleave", handleDragOut) + window.removeEventListener("dragover", handleDrag) + window.removeEventListener("drop", handleDrop) + window.removeEventListener("paste", onPaste) + } + }) + + return ( +
+ +
+ + {!file ? ( + { + setFile(f) + }} + /> + ) : ( + <> + )} +
+ ) +} + +export default Home diff --git a/custom-demo/back-end/web_app/src/assets/kofi_button_black.png b/custom-demo/back-end/web_app/src/assets/kofi_button_black.png new file mode 100644 index 0000000..0baa366 Binary files /dev/null and b/custom-demo/back-end/web_app/src/assets/kofi_button_black.png differ diff --git a/custom-demo/back-end/web_app/src/components/Coffee.tsx b/custom-demo/back-end/web_app/src/components/Coffee.tsx new file mode 100644 index 0000000..c674f54 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/Coffee.tsx @@ -0,0 +1,35 @@ +import { Coffee as CoffeeIcon } from "lucide-react" +import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "./ui/dialog" +import { IconButton } from "./ui/button" +import { DialogDescription } from "@radix-ui/react-dialog" +import Kofi from "@/assets/kofi_button_black.png" + +export function Coffee() { + return ( + + + + + + + + Buy me a coffee + + Hi, if you found my project is useful, please conside buy me a coffee + to support my work. Thanks! + +
+ + + +
+
+
+ ) +} + +export default Coffee diff --git a/custom-demo/back-end/web_app/src/components/Cropper.tsx b/custom-demo/back-end/web_app/src/components/Cropper.tsx new file mode 100644 index 0000000..94f5c96 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/Cropper.tsx @@ -0,0 +1,400 @@ +import { useStore } from "@/lib/states" +import { cn } from "@/lib/utils" +import React, { useEffect, useState } from "react" +import { twMerge } from "tailwind-merge" + +const DOC_MOVE_OPTS = { capture: true, passive: false } + +const DRAG_HANDLE_BORDER = 2 + +interface EVData { + initX: number + initY: number + initHeight: number + initWidth: number + startResizeX: number + startResizeY: number + ord: string // top/right/bottom/left +} + +interface Props { + maxHeight: number + maxWidth: number + scale: number + minHeight: number + minWidth: number + show: boolean +} + +const clamp = ( + newPos: number, + newLength: number, + oldPos: number, + oldLength: number, + minLength: number, + maxLength: number +) => { + if (newPos !== oldPos && newLength === oldLength) { + if (newPos < 0) { + return [0, oldLength] + } + if (newPos + newLength > maxLength) { + return [maxLength - oldLength, oldLength] + } + } else { + if (newLength < minLength) { + if (newPos === oldPos) { + return [newPos, minLength] + } + return [newPos + newLength - minLength, minLength] + } + if (newPos < 0) { + return [0, newPos + newLength] + } + if (newPos + newLength > maxLength) { + return [newPos, maxLength - newPos] + } + } + + return [newPos, newLength] +} + +const Cropper = (props: Props) => { + const { minHeight, minWidth, maxHeight, maxWidth, scale, show } = props + + const [ + imageWidth, + imageHeight, + isInpainting, + isSD, + { x, y, width, height }, + setX, + setY, + setWidth, + setHeight, + isResizing, + setIsResizing, + ] = useStore((state) => [ + state.imageWidth, + state.imageHeight, + state.isInpainting, + state.isSD(), + state.cropperState, + state.setCropperX, + state.setCropperY, + state.setCropperWidth, + state.setCropperHeight, + state.isCropperExtenderResizing, + state.setIsCropperExtenderResizing, + ]) + + // const [isResizing, setIsResizing] = useState(false) + const [isMoving, setIsMoving] = useState(false) + + useEffect(() => { + setX(Math.round((maxWidth - 512) / 2)) + setY(Math.round((maxHeight - 512) / 2)) + // TODO: 换了一张较小的图片,cropper 的起始位置和边界要修改 + // TODO: 一开始的 scale 不对 + }, [maxHeight, maxWidth, imageWidth, imageHeight]) + + const [evData, setEVData] = useState({ + initX: 0, + initY: 0, + initHeight: 0, + initWidth: 0, + startResizeX: 0, + startResizeY: 0, + ord: "top", + }) + + const onDragFocus = () => { + // console.log("focus") + } + + const clampLeftRight = (newX: number, newWidth: number) => { + return clamp(newX, newWidth, x, width, minWidth, maxWidth) + } + + const clampTopBottom = (newY: number, newHeight: number) => { + return clamp(newY, newHeight, y, height, minHeight, maxHeight) + } + + const onPointerMove = (e: PointerEvent) => { + if (isInpainting) { + return + } + const curX = e.clientX + const curY = e.clientY + + const offsetY = Math.round((curY - evData.startResizeY) / scale) + const offsetX = Math.round((curX - evData.startResizeX) / scale) + + const moveTop = () => { + const newHeight = evData.initHeight - offsetY + const newY = evData.initY + offsetY + const [clampedY, clampedHeight] = clampTopBottom(newY, newHeight) + setHeight(clampedHeight) + setY(clampedY) + } + + const moveBottom = () => { + const newHeight = evData.initHeight + offsetY + const [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight) + setHeight(clampedHeight) + setY(clampedY) + } + + const moveLeft = () => { + const newWidth = evData.initWidth - offsetX + const newX = evData.initX + offsetX + const [clampedX, clampedWidth] = clampLeftRight(newX, newWidth) + setWidth(clampedWidth) + setX(clampedX) + } + + const moveRight = () => { + const newWidth = evData.initWidth + offsetX + const [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth) + setWidth(clampedWidth) + setX(clampedX) + } + + if (isResizing) { + switch (evData.ord) { + case "topleft": { + moveTop() + moveLeft() + break + } + case "topright": { + moveTop() + moveRight() + break + } + case "bottomleft": { + moveBottom() + moveLeft() + break + } + case "bottomright": { + moveBottom() + moveRight() + break + } + case "top": { + moveTop() + break + } + case "right": { + moveRight() + break + } + case "bottom": { + moveBottom() + break + } + case "left": { + moveLeft() + break + } + + default: + break + } + } + + if (isMoving) { + const newX = evData.initX + offsetX + const newY = evData.initY + offsetY + const [clampedX, clampedWidth] = clampLeftRight(newX, evData.initWidth) + const [clampedY, clampedHeight] = clampTopBottom(newY, evData.initHeight) + setWidth(clampedWidth) + setHeight(clampedHeight) + setX(clampedX) + setY(clampedY) + } + } + + const onPointerDone = () => { + if (isResizing) { + setIsResizing(false) + } + + if (isMoving) { + setIsMoving(false) + } + } + + useEffect(() => { + if (isResizing || isMoving) { + document.addEventListener("pointermove", onPointerMove, DOC_MOVE_OPTS) + document.addEventListener("pointerup", onPointerDone, DOC_MOVE_OPTS) + document.addEventListener("pointercancel", onPointerDone, DOC_MOVE_OPTS) + return () => { + document.removeEventListener( + "pointermove", + onPointerMove, + DOC_MOVE_OPTS + ) + document.removeEventListener("pointerup", onPointerDone, DOC_MOVE_OPTS) + document.removeEventListener( + "pointercancel", + onPointerDone, + DOC_MOVE_OPTS + ) + } + } + }, [isResizing, isMoving, width, height, evData]) + + const onCropPointerDown = (e: React.PointerEvent) => { + const { ord } = (e.target as HTMLElement).dataset + if (ord) { + setIsResizing(true) + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord, + }) + } + } + + const createDragHandle = (cursor: string, side1: string, side2: string) => { + const sideLength = 12 + const halfSideLength = sideLength / 2 + const draghandleCls = `w-[${sideLength}px] h-[${sideLength}px] z-[4] absolute content-[''] block border-2 border-primary borde pointer-events-auto hover:bg-primary` + + let xTrans = "0" + let yTrans = "0" + + let side2Key = side2 + let side2Val = `${-halfSideLength}px` + if (side2 === "") { + side2Val = "50%" + if (side1 === "left" || side1 === "right") { + side2Key = "top" + yTrans = "-50%" + } else { + side2Key = "left" + xTrans = "-50%" + } + } + + return ( +
+ ) + } + + const createCropSelection = () => { + return ( +
+
+
+
+
+ {createDragHandle("cursor-nw-resize", "top", "left")} + {createDragHandle("cursor-ne-resize", "top", "right")} + {createDragHandle("cursor-sw-resize", "bottom", "left")} + {createDragHandle("cursor-se-resize", "bottom", "right")} + {createDragHandle("cursor-ns-resize", "top", "")} + {createDragHandle("cursor-ns-resize", "bottom", "")} + {createDragHandle("cursor-ew-resize", "left", "")} + {createDragHandle("cursor-ew-resize", "right", "")} +
+ ) + } + + const onInfoBarPointerDown = (e: React.PointerEvent) => { + setIsMoving(true) + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord: "", + }) + } + + const createInfoBar = () => { + return ( +
+ {/* TODO: 移动的时候会显示 brush */} + {width} x {height} +
+ ) + } + + const createBorder = () => { + return ( +
+ ) + } + + if (show === false || !isSD) { + return null + } + + return ( +
+
+ {createBorder()} + {createInfoBar()} + {createCropSelection()} +
+
+ ) +} + +export default Cropper diff --git a/custom-demo/back-end/web_app/src/components/DiffusionProgress.tsx b/custom-demo/back-end/web_app/src/components/DiffusionProgress.tsx new file mode 100644 index 0000000..2e2e1a4 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/DiffusionProgress.tsx @@ -0,0 +1,63 @@ +import * as React from "react" +import io from "socket.io-client" +import { Progress } from "./ui/progress" +import { useStore } from "@/lib/states" + +export const API_ENDPOINT = import.meta.env.DEV + ? import.meta.env.VITE_BACKEND + : "" +const socket = io(API_ENDPOINT) + +const DiffusionProgress = () => { + const [settings, isInpainting, isSD] = useStore((state) => [ + state.settings, + state.isInpainting, + state.isSD(), + ]) + + const [isConnected, setIsConnected] = React.useState(false) + const [step, setStep] = React.useState(0) + + const progress = Math.min(Math.round((step / settings.sdSteps) * 100), 100) + + React.useEffect(() => { + socket.on("connect", () => { + setIsConnected(true) + }) + + socket.on("disconnect", () => { + setIsConnected(false) + }) + + socket.on("diffusion_progress", (data) => { + if (data) { + setStep(data.step + 1) + } + }) + + socket.on("diffusion_finish", () => { + setStep(0) + }) + + return () => { + socket.off("connect") + socket.off("disconnect") + socket.off("diffusion_progress") + socket.off("diffusion_finish") + } + }, []) + + return ( +
+ +
{progress}%
+
+ ) +} + +export default DiffusionProgress diff --git a/custom-demo/back-end/web_app/src/components/Editor.tsx b/custom-demo/back-end/web_app/src/components/Editor.tsx new file mode 100644 index 0000000..5bd4ef7 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/Editor.tsx @@ -0,0 +1,989 @@ +import { SyntheticEvent, useCallback, useEffect, useRef, useState } from "react" +import { CursorArrowRaysIcon } from "@heroicons/react/24/outline" +import { useToast } from "@/components/ui/use-toast" +import { + ReactZoomPanPinchContentRef, + TransformComponent, + TransformWrapper, +} from "react-zoom-pan-pinch" +import { useKeyPressEvent } from "react-use" +import { downloadToOutput, runPlugin } from "@/lib/api" +import { IconButton } from "@/components/ui/button" +import { + askWritePermission, + cn, + copyCanvasImage, + downloadImage, + drawLines, + generateMask, + isMidClick, + isRightClick, + mouseXY, + srcToFile, +} from "@/lib/utils" +import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react" +import { useImage } from "@/hooks/useImage" +import { Slider } from "./ui/slider" +import { PluginName } from "@/lib/types" +import { useStore } from "@/lib/states" +import Cropper from "./Cropper" +import { InteractiveSegPoints } from "./InteractiveSeg" +import useHotKey from "@/hooks/useHotkey" +import Extender from "./Extender" +import { MAX_BRUSH_SIZE, MIN_BRUSH_SIZE } from "@/lib/const" + +const TOOLBAR_HEIGHT = 200 +const COMPARE_SLIDER_DURATION_MS = 300 + +interface EditorProps { + file: File +} + +export default function Editor(props: EditorProps) { + const { file } = props + const { toast } = useToast() + + const [ + disableShortCuts, + windowSize, + isInpainting, + imageWidth, + imageHeight, + settings, + enableAutoSaving, + setImageSize, + setBaseBrushSize, + interactiveSegState, + updateInteractiveSegState, + handleCanvasMouseDown, + handleCanvasMouseMove, + undo, + redo, + undoDisabled, + redoDisabled, + isProcessing, + updateAppState, + runMannually, + runInpainting, + isCropperExtenderResizing, + decreaseBaseBrushSize, + increaseBaseBrushSize, + ] = useStore((state) => [ + state.disableShortCuts, + state.windowSize, + state.isInpainting, + state.imageWidth, + state.imageHeight, + state.settings, + state.serverConfig.enableAutoSaving, + state.setImageSize, + state.setBaseBrushSize, + state.interactiveSegState, + state.updateInteractiveSegState, + state.handleCanvasMouseDown, + state.handleCanvasMouseMove, + state.undo, + state.redo, + state.undoDisabled(), + state.redoDisabled(), + state.getIsProcessing(), + state.updateAppState, + state.runMannually(), + state.runInpainting, + state.isCropperExtenderResizing, + state.decreaseBaseBrushSize, + state.increaseBaseBrushSize, + ]) + const baseBrushSize = useStore((state) => state.editorState.baseBrushSize) + const brushSize = useStore((state) => state.getBrushSize()) + const renders = useStore((state) => state.editorState.renders) + const extraMasks = useStore((state) => state.editorState.extraMasks) + const temporaryMasks = useStore((state) => state.editorState.temporaryMasks) + const lineGroups = useStore((state) => state.editorState.lineGroups) + const curLineGroup = useStore((state) => state.editorState.curLineGroup) + + // Local State + const [showOriginal, setShowOriginal] = useState(false) + const [original, isOriginalLoaded] = useImage(file) + const [context, setContext] = useState() + const [imageContext, setImageContext] = useState() + const [{ x, y }, setCoords] = useState({ x: -1, y: -1 }) + const [showBrush, setShowBrush] = useState(false) + const [showRefBrush, setShowRefBrush] = useState(false) + const [isPanning, setIsPanning] = useState(false) + + const [scale, setScale] = useState(1) + const [panned, setPanned] = useState(false) + const [minScale, setMinScale] = useState(1.0) + const windowCenterX = windowSize.width / 2 + const windowCenterY = windowSize.height / 2 + const viewportRef = useRef(null) + // Indicates that the image has been loaded and is centered on first load + const [initialCentered, setInitialCentered] = useState(false) + + const [isDraging, setIsDraging] = useState(false) + + const [sliderPos, setSliderPos] = useState(0) + const [isChangingBrushSizeByWheel, setIsChangingBrushSizeByWheel] = + useState(false) + + const hadDrawSomething = useCallback(() => { + return curLineGroup.length !== 0 + }, [curLineGroup]) + + useEffect(() => { + if ( + !imageContext || + !isOriginalLoaded || + imageWidth === 0 || + imageHeight === 0 + ) { + return + } + const render = renders.length === 0 ? original : renders[renders.length - 1] + imageContext.canvas.width = imageWidth + imageContext.canvas.height = imageHeight + + imageContext.clearRect( + 0, + 0, + imageContext.canvas.width, + imageContext.canvas.height + ) + imageContext.drawImage(render, 0, 0, imageWidth, imageHeight) + }, [ + renders, + original, + isOriginalLoaded, + imageContext, + imageHeight, + imageWidth, + ]) + + useEffect(() => { + if ( + !context || + !isOriginalLoaded || + imageWidth === 0 || + imageHeight === 0 + ) { + return + } + context.canvas.width = imageWidth + context.canvas.height = imageHeight + context.clearRect(0, 0, context.canvas.width, context.canvas.height) + temporaryMasks.forEach((maskImage) => { + context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) + }) + extraMasks.forEach((maskImage) => { + context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) + }) + + if ( + interactiveSegState.isInteractiveSeg && + interactiveSegState.tmpInteractiveSegMask + ) { + context.drawImage( + interactiveSegState.tmpInteractiveSegMask, + 0, + 0, + imageWidth, + imageHeight + ) + } + drawLines(context, curLineGroup) + }, [ + temporaryMasks, + extraMasks, + isOriginalLoaded, + interactiveSegState, + context, + curLineGroup, + imageHeight, + imageWidth, + ]) + + const getCurrentRender = useCallback(async () => { + let targetFile = file + if (renders.length > 0) { + const lastRender = renders[renders.length - 1] + targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type) + } + return targetFile + }, [file, renders]) + + const hadRunInpainting = () => { + return renders.length !== 0 + } + + const getCurrentWidthHeight = useCallback(() => { + let width = 512 + let height = 512 + if (!isOriginalLoaded) { + return [width, height] + } + if (renders.length === 0) { + width = original.naturalWidth + height = original.naturalHeight + } else if (renders.length !== 0) { + width = renders[renders.length - 1].width + height = renders[renders.length - 1].height + } + + return [width, height] + }, [original, isOriginalLoaded, renders]) + + // Draw once the original image is loaded + useEffect(() => { + if (!isOriginalLoaded) { + return + } + + const [width, height] = getCurrentWidthHeight() + if (width !== imageWidth || height !== imageHeight) { + setImageSize(width, height) + } + + const rW = windowSize.width / width + const rH = (windowSize.height - TOOLBAR_HEIGHT) / height + + let s = 1.0 + if (rW < 1 || rH < 1) { + s = Math.min(rW, rH) + } + setMinScale(s) + setScale(s) + + console.log( + `[on file load] image size: ${width}x${height}, scale: ${s}, initialCentered: ${initialCentered}` + ) + + if (context?.canvas) { + console.log("[on file load] set canvas size") + if (width != context.canvas.width) { + context.canvas.width = width + } + if (height != context.canvas.height) { + context.canvas.height = height + } + } + + if (!initialCentered) { + // 防止每次擦除以后图片 zoom 还原 + viewportRef.current?.centerView(s, 1) + console.log("[on file load] centerView") + setInitialCentered(true) + } + }, [ + viewportRef, + imageHeight, + imageWidth, + original, + isOriginalLoaded, + windowSize, + initialCentered, + getCurrentWidthHeight, + ]) + + useEffect(() => { + console.log("[useEffect] centerView") + // render 改变尺寸以后,undo/redo 重新 center + viewportRef?.current?.centerView(minScale, 1) + }, [imageHeight, imageWidth, viewportRef, minScale]) + + // Zoom reset + const resetZoom = useCallback(() => { + if (!minScale || !windowSize) { + return + } + const viewport = viewportRef.current + if (!viewport) { + return + } + const offsetX = (windowSize.width - imageWidth * minScale) / 2 + const offsetY = (windowSize.height - imageHeight * minScale) / 2 + viewport.setTransform(offsetX, offsetY, minScale, 200, "easeOutQuad") + if (viewport.instance.transformState.scale) { + viewport.instance.transformState.scale = minScale + } + + setScale(minScale) + setPanned(false) + }, [ + viewportRef, + windowSize, + imageHeight, + imageWidth, + windowSize.height, + minScale, + ]) + + useEffect(() => { + window.addEventListener("resize", () => { + resetZoom() + }) + return () => { + window.removeEventListener("resize", () => { + resetZoom() + }) + } + }, [windowSize, resetZoom]) + + const handleEscPressed = () => { + if (isProcessing) { + return + } + + if (isDraging) { + setIsDraging(false) + } else { + resetZoom() + } + } + + useHotKey("Escape", handleEscPressed, [ + isDraging, + isInpainting, + resetZoom, + // drawOnCurrentRender, + ]) + + const onMouseMove = (ev: SyntheticEvent) => { + const mouseEvent = ev.nativeEvent as MouseEvent + setCoords({ x: mouseEvent.pageX, y: mouseEvent.pageY }) + } + + const onMouseDrag = (ev: SyntheticEvent) => { + if (isProcessing) { + return + } + + if (interactiveSegState.isInteractiveSeg) { + return + } + if (isPanning) { + return + } + if (!isDraging) { + return + } + if (curLineGroup.length === 0) { + return + } + + handleCanvasMouseMove(mouseXY(ev)) + } + + const runInteractiveSeg = async (newClicks: number[][]) => { + updateAppState({ isPluginRunning: true }) + const targetFile = await getCurrentRender() + try { + const res = await runPlugin( + true, + PluginName.InteractiveSeg, + targetFile, + undefined, + newClicks + ) + const { blob } = res + const img = new Image() + img.onload = () => { + updateInteractiveSegState({ tmpInteractiveSegMask: img }) + } + img.src = blob + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + } + updateAppState({ isPluginRunning: false }) + } + + const onPointerUp = (ev: SyntheticEvent) => { + if (isMidClick(ev)) { + setIsPanning(false) + return + } + if (!hadDrawSomething()) { + return + } + if (interactiveSegState.isInteractiveSeg) { + return + } + if (isPanning) { + return + } + if (!original.src) { + return + } + const canvas = context?.canvas + if (!canvas) { + return + } + if (isInpainting) { + return + } + if (!isDraging) { + return + } + + if (runMannually) { + setIsDraging(false) + } else { + runInpainting() + } + } + + const onCanvasMouseUp = (ev: SyntheticEvent) => { + if (interactiveSegState.isInteractiveSeg) { + const xy = mouseXY(ev) + const newClicks: number[][] = [...interactiveSegState.clicks] + if (isRightClick(ev)) { + newClicks.push([xy.x, xy.y, 0, newClicks.length]) + } else { + newClicks.push([xy.x, xy.y, 1, newClicks.length]) + } + runInteractiveSeg(newClicks) + updateInteractiveSegState({ clicks: newClicks }) + } + } + + const onMouseDown = (ev: SyntheticEvent) => { + if (isProcessing) { + return + } + if (interactiveSegState.isInteractiveSeg) { + return + } + if (isPanning) { + return + } + if (!isOriginalLoaded) { + return + } + const canvas = context?.canvas + if (!canvas) { + return + } + + if (isRightClick(ev)) { + return + } + + if (isMidClick(ev)) { + setIsPanning(true) + return + } + + setIsDraging(true) + handleCanvasMouseDown(mouseXY(ev)) + } + + const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { + keyboardEvent.preventDefault() + undo() + } + useHotKey("meta+z,ctrl+z", handleUndo) + + const handleRedo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { + keyboardEvent.preventDefault() + redo() + } + useHotKey("shift+ctrl+z,shift+meta+z", handleRedo) + + useKeyPressEvent( + "Tab", + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + if (hadRunInpainting()) { + setShowOriginal(() => { + window.setTimeout(() => { + setSliderPos(100) + }, 10) + return true + }) + } + }, + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + if (hadRunInpainting()) { + window.setTimeout(() => { + setSliderPos(0) + }, 10) + window.setTimeout(() => { + setShowOriginal(false) + }, COMPARE_SLIDER_DURATION_MS) + } + } + ) + + const download = useCallback(async () => { + if (file === undefined) { + return + } + if (enableAutoSaving && renders.length > 0) { + try { + await downloadToOutput( + renders[renders.length - 1], + file.name, + file.type + ) + toast({ + description: "Save image success", + }) + } catch (e: any) { + toast({ + variant: "destructive", + title: "Uh oh! Something went wrong.", + description: e.message ? e.message : e.toString(), + }) + } + return + } + + // TODO: download to output directory + const name = file.name.replace(/(\.[\w\d_-]+)$/i, "_cleanup$1") + const curRender = renders[renders.length - 1] + downloadImage(curRender.currentSrc, name) + if (settings.enableDownloadMask) { + let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1") + maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg") + + const maskCanvas = generateMask(imageWidth, imageHeight, lineGroups) + // Create a link + const aDownloadLink = document.createElement("a") + // Add the name of the file to the link + aDownloadLink.download = maskFileName + // Attach the data to the link + aDownloadLink.href = maskCanvas.toDataURL("image/jpeg") + // Get the code to click the download link + aDownloadLink.click() + } + }, [ + file, + enableAutoSaving, + renders, + settings, + imageHeight, + imageWidth, + lineGroups, + ]) + + useHotKey("meta+s,ctrl+s", download) + + const toggleShowBrush = (newState: boolean) => { + if (newState !== showBrush && !isPanning && !isCropperExtenderResizing) { + setShowBrush(newState) + } + } + + const getCursor = useCallback(() => { + if (isProcessing) { + return "default" + } + if (isPanning) { + return "grab" + } + if (showBrush) { + return "none" + } + return undefined + }, [showBrush, isPanning, isProcessing]) + + useHotKey( + "[", + () => { + decreaseBaseBrushSize() + }, + [decreaseBaseBrushSize] + ) + + useHotKey( + "]", + () => { + increaseBaseBrushSize() + }, + [increaseBaseBrushSize] + ) + + // Manual Inpainting Hotkey + useHotKey( + "shift+r", + () => { + if (runMannually && hadDrawSomething()) { + runInpainting() + } + }, + [runMannually, runInpainting, hadDrawSomething] + ) + + useHotKey( + "ctrl+c,meta+c", + async () => { + const hasPermission = await askWritePermission() + if (hasPermission && renders.length > 0) { + if (context?.canvas) { + await copyCanvasImage(context?.canvas) + toast({ + title: "Copy inpainting result to clipboard", + }) + } + } + }, + [renders, context] + ) + + // Toggle clean/zoom tool on spacebar. + useKeyPressEvent( + " ", + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault() + ev?.stopPropagation() + setShowBrush(false) + setIsPanning(true) + } + }, + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault() + ev?.stopPropagation() + setShowBrush(true) + setIsPanning(false) + } + } + ) + + useKeyPressEvent( + "Alt", + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault() + ev?.stopPropagation() + setIsChangingBrushSizeByWheel(true) + } + }, + (ev) => { + if (!disableShortCuts) { + ev?.preventDefault() + ev?.stopPropagation() + setIsChangingBrushSizeByWheel(false) + } + } + ) + + const getCurScale = (): number => { + let s = minScale + if (viewportRef.current?.instance?.transformState.scale !== undefined) { + s = viewportRef.current?.instance?.transformState.scale + } + return s! + } + + const getBrushStyle = (_x: number, _y: number) => { + const curScale = getCurScale() + return { + width: `${brushSize * curScale}px`, + height: `${brushSize * curScale}px`, + left: `${_x}px`, + top: `${_y}px`, + transform: "translate(-50%, -50%)", + } + } + + const renderBrush = (style: any) => { + return ( +
+ ) + } + + const handleSliderChange = (value: number) => { + setBaseBrushSize(value) + + if (!showRefBrush) { + setShowRefBrush(true) + window.setTimeout(() => { + setShowRefBrush(false) + }, 10000) + } + } + + const renderInteractiveSegCursor = () => { + return ( +
+ +
+ ) + } + + const renderCanvas = () => { + return ( + { + if (r) { + viewportRef.current = r + } + }} + panning={{ disabled: !isPanning, velocityDisabled: true }} + wheel={{ step: 0.05, wheelDisabled: isChangingBrushSizeByWheel }} + centerZoomedOut + alignmentAnimation={{ disabled: true }} + centerOnInit + limitToBounds={false} + doubleClick={{ disabled: true }} + initialScale={minScale} + minScale={minScale * 0.3} + onPanning={() => { + if (!panned) { + setPanned(true) + } + }} + onZoom={(ref) => { + setScale(ref.state.scale) + }} + > + +
+ { + if (r && !imageContext) { + const ctx = r.getContext("2d") + if (ctx) { + setImageContext(ctx) + } + } + }} + /> + { + e.preventDefault() + }} + onMouseOver={() => { + toggleShowBrush(true) + setShowRefBrush(false) + }} + onFocus={() => toggleShowBrush(true)} + onMouseLeave={() => toggleShowBrush(false)} + onMouseDown={onMouseDown} + onMouseUp={onCanvasMouseUp} + onMouseMove={onMouseDrag} + ref={(r) => { + if (r && !context) { + const ctx = r.getContext("2d") + if (ctx) { + setContext(ctx) + } + } + }} + /> +
+ {showOriginal && ( + <> +
+ original + + )} +
+
+ + + + + + {interactiveSegState.isInteractiveSeg ? ( + + ) : ( + <> + )} + + + ) + } + + const handleScroll = (event: React.WheelEvent) => { + // deltaY 是垂直滚动增量,正值表示向下滚动,负值表示向上滚动 + // deltaX 是水平滚动增量,正值表示向右滚动,负值表示向左滚动 + if (!isChangingBrushSizeByWheel) { + return + } + + const { deltaY } = event + // console.log(`水平滚动增量: ${deltaX}, 垂直滚动增量: ${deltaY}`) + if (deltaY > 0) { + increaseBaseBrushSize() + } else if (deltaY < 0) { + decreaseBaseBrushSize() + } + } + + return ( + + ) +} diff --git a/custom-demo/back-end/web_app/src/components/Extender.tsx b/custom-demo/back-end/web_app/src/components/Extender.tsx new file mode 100644 index 0000000..409a855 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/Extender.tsx @@ -0,0 +1,414 @@ +import { useStore } from "@/lib/states" +import { ExtenderDirection } from "@/lib/types" +import { cn } from "@/lib/utils" +import React, { useEffect, useState } from "react" +import { twMerge } from "tailwind-merge" + +const DOC_MOVE_OPTS = { capture: true, passive: false } + +const DRAG_HANDLE_BORDER = 2 + +interface EVData { + initX: number + initY: number + initHeight: number + initWidth: number + startResizeX: number + startResizeY: number + ord: string // top/right/bottom/left +} + +interface Props { + scale: number + minHeight: number + minWidth: number + show: boolean +} + +const clamp = ( + newPos: number, + newLength: number, + oldPos: number, + minLength: number +) => { + if (newLength < minLength) { + if (newPos === oldPos) { + return [newPos, minLength] + } + return [newPos + newLength - minLength, minLength] + } + + return [newPos, newLength] +} + +const Extender = (props: Props) => { + const { minHeight, minWidth, scale, show } = props + + const [ + isInpainting, + imageHeight, + imageWdith, + isSD, + { x, y, width, height }, + setX, + setY, + setWidth, + setHeight, + extenderDirection, + isResizing, + setIsResizing, + ] = useStore((state) => [ + state.isInpainting, + state.imageHeight, + state.imageWidth, + state.isSD(), + state.extenderState, + state.setExtenderX, + state.setExtenderY, + state.setExtenderWidth, + state.setExtenderHeight, + state.settings.extenderDirection, + state.isCropperExtenderResizing, + state.setIsCropperExtenderResizing, + ]) + + const [evData, setEVData] = useState({ + initX: 0, + initY: 0, + initHeight: 0, + initWidth: 0, + startResizeX: 0, + startResizeY: 0, + ord: "top", + }) + + const onDragFocus = () => { + // console.log("focus") + } + + const clampLeftRight = (newX: number, newWidth: number) => { + return clamp(newX, newWidth, x, minWidth) + } + + const clampTopBottom = (newY: number, newHeight: number) => { + return clamp(newY, newHeight, y, minHeight) + } + + const onPointerMove = (e: PointerEvent) => { + if (isInpainting) { + return + } + const curX = e.clientX + const curY = e.clientY + + const offsetY = Math.round((curY - evData.startResizeY) / scale) + const offsetX = Math.round((curX - evData.startResizeX) / scale) + + const moveTop = () => { + const newHeight = evData.initHeight - offsetY + const newY = evData.initY + offsetY + let clampedY = newY + let clampedHeight = newHeight + if (extenderDirection === ExtenderDirection.xy) { + if (clampedY > 0) { + clampedY = 0 + clampedHeight = evData.initHeight - Math.abs(evData.initY) + } + } else { + const clamped = clampTopBottom(newY, newHeight) + clampedY = clamped[0] + clampedHeight = clamped[1] + } + setHeight(clampedHeight) + setY(clampedY) + } + + const moveBottom = () => { + const newHeight = evData.initHeight + offsetY + let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight) + if (extenderDirection === ExtenderDirection.xy) { + if (clampedHeight < Math.abs(clampedY) + imageHeight) { + clampedHeight = Math.abs(clampedY) + imageHeight + } + } + setHeight(clampedHeight) + setY(clampedY) + } + + const moveLeft = () => { + const newWidth = evData.initWidth - offsetX + const newX = evData.initX + offsetX + let clampedX = newX + let clampedWidth = newWidth + if (extenderDirection === ExtenderDirection.xy) { + if (clampedX > 0) { + clampedX = 0 + clampedWidth = evData.initWidth - Math.abs(evData.initX) + } + } else { + const clamped = clampLeftRight(newX, newWidth) + clampedX = clamped[0] + clampedWidth = clamped[1] + } + setWidth(clampedWidth) + setX(clampedX) + } + + const moveRight = () => { + const newWidth = evData.initWidth + offsetX + let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth) + if (extenderDirection === ExtenderDirection.xy) { + if (clampedWidth < Math.abs(clampedX) + imageWdith) { + clampedWidth = Math.abs(clampedX) + imageWdith + } + } + setWidth(clampedWidth) + setX(clampedX) + } + + if (isResizing) { + switch (evData.ord) { + case "topleft": { + moveTop() + moveLeft() + break + } + case "topright": { + moveTop() + moveRight() + break + } + case "bottomleft": { + moveBottom() + moveLeft() + break + } + case "bottomright": { + moveBottom() + moveRight() + break + } + case "top": { + moveTop() + break + } + case "right": { + moveRight() + break + } + case "bottom": { + moveBottom() + break + } + case "left": { + moveLeft() + break + } + + default: + break + } + } + } + + const onPointerDone = () => { + if (isResizing) { + setIsResizing(false) + } + } + + useEffect(() => { + if (isResizing) { + document.addEventListener("pointermove", onPointerMove, DOC_MOVE_OPTS) + document.addEventListener("pointerup", onPointerDone, DOC_MOVE_OPTS) + document.addEventListener("pointercancel", onPointerDone, DOC_MOVE_OPTS) + return () => { + document.removeEventListener( + "pointermove", + onPointerMove, + DOC_MOVE_OPTS + ) + document.removeEventListener("pointerup", onPointerDone, DOC_MOVE_OPTS) + document.removeEventListener( + "pointercancel", + onPointerDone, + DOC_MOVE_OPTS + ) + } + } + }, [isResizing, width, height, evData]) + + const onCropPointerDown = (e: React.PointerEvent) => { + const { ord } = (e.target as HTMLElement).dataset + if (ord) { + setIsResizing(true) + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord, + }) + } + } + + const createDragHandle = (cursor: string, side1: string, side2: string) => { + const sideLength = 12 + const halfSideLength = sideLength / 2 + const draghandleCls = `w-[${sideLength}px] h-[${sideLength}px] z-[4] absolute content-[''] block border-2 border-primary borde pointer-events-auto hover:bg-primary` + + let xTrans = "0" + let yTrans = "0" + + let side2Key = side2 + let side2Val = `${-halfSideLength}px` + if (side2 === "") { + side2Val = "50%" + if (side1 === "left" || side1 === "right") { + side2Key = "top" + yTrans = "-50%" + } else { + side2Key = "left" + xTrans = "-50%" + } + } + + return ( +
+ ) + } + + const createCropSelection = () => { + return ( +
+ {[ExtenderDirection.y, ExtenderDirection.xy].includes( + extenderDirection + ) ? ( + <> +
+
+ {createDragHandle("cursor-ns-resize", "top", "")} + {createDragHandle("cursor-ns-resize", "bottom", "")} + + ) : ( + <> + )} + + {[ExtenderDirection.x, ExtenderDirection.xy].includes( + extenderDirection + ) ? ( + <> +
+
+ {createDragHandle("cursor-ew-resize", "left", "")} + {createDragHandle("cursor-ew-resize", "right", "")} + + ) : ( + <> + )} + + {extenderDirection === ExtenderDirection.xy ? ( + <> + {createDragHandle("cursor-nw-resize", "top", "left")} + {createDragHandle("cursor-ne-resize", "top", "right")} + {createDragHandle("cursor-sw-resize", "bottom", "left")} + {createDragHandle("cursor-se-resize", "bottom", "right")} + + ) : ( + <> + )} +
+ ) + } + + const onInfoBarPointerDown = (e: React.PointerEvent) => { + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord: "", + }) + } + + const createInfoBar = () => { + return ( +
+ {/* TODO: 移动的时候会显示 brush */} + {width} x {height} +
+ ) + } + + const createBorder = () => { + return ( +
+ ) + } + + if (show === false || !isSD) { + return null + } + + return ( +
+
+ {createBorder()} + {createInfoBar()} + {createCropSelection()} +
+
+ ) +} + +export default Extender diff --git a/custom-demo/back-end/web_app/src/components/FileManager.tsx b/custom-demo/back-end/web_app/src/components/FileManager.tsx new file mode 100644 index 0000000..84e04ed --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/FileManager.tsx @@ -0,0 +1,343 @@ +import { + SyntheticEvent, + useEffect, + useState, + useCallback, + useRef, + FormEvent, +} from "react" +import _ from "lodash" +import PhotoAlbum from "react-photo-album" +import { BarsArrowDownIcon, BarsArrowUpIcon } from "@heroicons/react/24/outline" +import { + MagnifyingGlassIcon, + ViewHorizontalIcon, + ViewGridIcon, +} from "@radix-ui/react-icons" +import { useToggle } from "react-use" +import { useDebounce } from "@uidotdev/usehooks" +import Fuse from "fuse.js" +import { useToast } from "@/components/ui/use-toast" +import { API_ENDPOINT, getMedias } from "@/lib/api" +import { IconButton } from "./ui/button" +import { Input } from "./ui/input" +import { Dialog, DialogContent, DialogTitle } from "./ui/dialog" +import { Tabs, TabsList, TabsTrigger } from "./ui/tabs" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "./ui/select" +import { ScrollArea } from "./ui/scroll-area" +import { DialogTrigger } from "@radix-ui/react-dialog" +import { useStore } from "@/lib/states" +import { Filename, SortBy, SortOrder } from "@/lib/types" +import { FolderClosed } from "lucide-react" +import useHotKey from "@/hooks/useHotkey" + +interface Photo { + src: string + height: number + width: number + name: string +} + +const SORT_BY_NAME = "Name" +const SORT_BY_CREATED_TIME = "Created time" +const SORT_BY_MODIFIED_TIME = "Modified time" + +const IMAGE_TAB = "input" +const OUTPUT_TAB = "output" + +const SortByMap = { + [SortBy.NAME]: SORT_BY_NAME, + [SortBy.CTIME]: SORT_BY_CREATED_TIME, + [SortBy.MTIME]: SORT_BY_MODIFIED_TIME, +} + +interface Props { + onPhotoClick(tab: string, filename: string): void + photoWidth: number +} + +export default function FileManager(props: Props) { + const { onPhotoClick, photoWidth } = props + const [open, toggleOpen] = useToggle(false) + + const [fileManagerState, updateFileManagerState] = useStore((state) => [ + state.fileManagerState, + state.updateFileManagerState, + ]) + + const { toast } = useToast() + const [scrollTop, setScrollTop] = useState(0) + const [closeScrollTop, setCloseScrollTop] = useState(0) + + const ref = useRef(null) + const debouncedSearchText = useDebounce(fileManagerState.searchText, 300) + const [tab, setTab] = useState(IMAGE_TAB) + const [filenames, setFilenames] = useState([]) + const [photos, setPhotos] = useState([]) + const [photoIndex, setPhotoIndex] = useState(0) + + useHotKey("f", () => { + toggleOpen() + }) + + useHotKey( + "left", + () => { + let newIndex = photoIndex + if (photoIndex > 0) { + newIndex = photoIndex - 1 + } + setPhotoIndex(newIndex) + onPhotoClick(tab, photos[newIndex].name) + }, + [photoIndex, photos] + ) + + useHotKey( + "right", + () => { + let newIndex = photoIndex + if (photoIndex < photos.length - 1) { + newIndex = photoIndex + 1 + } + setPhotoIndex(newIndex) + onPhotoClick(tab, photos[newIndex].name) + }, + [photoIndex, photos] + ) + + useEffect(() => { + if (!open) { + setCloseScrollTop(scrollTop) + } + }, [open, scrollTop]) + + const onRefChange = useCallback( + (node: HTMLDivElement) => { + if (node !== null) { + if (open) { + setTimeout(() => { + // TODO: without timeout, scrollTo not work, why? + node.scrollTo({ top: closeScrollTop, left: 0 }) + }, 100) + } + } + }, + [open, closeScrollTop] + ) + + useEffect(() => { + const fetchData = async () => { + try { + const filenames = await getMedias(tab) + setFilenames(filenames) + } catch (e: any) { + toast({ + variant: "destructive", + title: "Uh oh! Something went wrong.", + description: e.message ? e.message : e.toString(), + }) + } + } + fetchData() + }, [tab]) + + useEffect(() => { + if (!open) { + return + } + const fetchData = async () => { + try { + let filteredFilenames = filenames + if (debouncedSearchText) { + const fuse = new Fuse(filteredFilenames, { + keys: ["name"], + }) + const items = fuse.search(debouncedSearchText) + filteredFilenames = items.map( + (item) => filteredFilenames[item.refIndex] + ) + } + + filteredFilenames = _.orderBy( + filteredFilenames, + fileManagerState.sortBy, + fileManagerState.sortOrder + ) + + const newPhotos = filteredFilenames.map((filename: Filename) => { + const width = photoWidth + const height = filename.height * (width / filename.width) + const src = `${API_ENDPOINT}/media_thumbnail_file?tab=${tab}&filename=${encodeURIComponent( + filename.name + )}&width=${Math.ceil(width)}&height=${Math.ceil(height)}` + return { src, height, width, name: filename.name } + }) + setPhotos(newPhotos) + } catch (e: any) { + toast({ + variant: "destructive", + title: "Uh oh! Something went wrong.", + description: e.message ? e.message : e.toString(), + }) + } + } + fetchData() + }, [filenames, debouncedSearchText, fileManagerState, photoWidth, open]) + + const onScroll = (event: SyntheticEvent) => { + setScrollTop(event.currentTarget.scrollTop) + } + + const onClick = ({ index }: { index: number }) => { + toggleOpen() + setPhotoIndex(index) + onPhotoClick(tab, photos[index].name) + } + + const renderTitle = () => { + return ( +
+
{`Images (${photos.length})`}
+
+ { + updateFileManagerState({ layout: "rows" }) + }} + > + + + { + updateFileManagerState({ layout: "masonry" }) + }} + > + + +
+
+ ) + } + + return ( + + + + + + + + {renderTitle()} +
+
+ + ) => { + evt.preventDefault() + evt.stopPropagation() + const target = evt.target as HTMLInputElement + updateFileManagerState({ searchText: target.value }) + }} + placeholder="Search by file name" + /> +
+ + setTab(val)}> + + Image Directory + Output Directory + + + +
+
+ + + {fileManagerState.sortOrder === SortOrder.DESCENDING ? ( + { + updateFileManagerState({ sortOrder: SortOrder.ASCENDING }) + }} + > + + + ) : ( + { + updateFileManagerState({ sortOrder: SortOrder.DESCENDING }) + }} + > + + + )} +
+
+
+ + + + +
+
+ ) +} diff --git a/custom-demo/back-end/web_app/src/components/FileSelect.tsx b/custom-demo/back-end/web_app/src/components/FileSelect.tsx new file mode 100644 index 0000000..c75da22 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/FileSelect.tsx @@ -0,0 +1,71 @@ +import { useState } from "react" +import useResolution from "@/hooks/useResolution" + +type FileSelectProps = { + onSelection: (file: File) => void +} + +export default function FileSelect(props: FileSelectProps) { + const { onSelection } = props + + const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`) + + const resolution = useResolution() + + function onFileSelected(file: File) { + if (!file) { + return + } + // Skip non-image files + const isImage = file.type.match("image.*") + if (!isImage) { + return + } + try { + // Check if file is larger than 20mb + if (file.size > 20 * 1024 * 1024) { + throw new Error("file too large") + } + onSelection(file) + } catch (e) { + // eslint-disable-next-line + alert(`error: ${(e as any).message}`) + } + } + + return ( +
+ +
+ ) +} diff --git a/custom-demo/back-end/web_app/src/components/Header.tsx b/custom-demo/back-end/web_app/src/components/Header.tsx new file mode 100644 index 0000000..05acce6 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/Header.tsx @@ -0,0 +1,198 @@ +import { PlayIcon } from "@radix-ui/react-icons" +import { useState } from "react" +import { IconButton, ImageUploadButton } from "@/components/ui/button" +import Shortcuts from "@/components/Shortcuts" +import { useImage } from "@/hooks/useImage" + +import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover" +import PromptInput from "./PromptInput" +import { RotateCw, Image, Upload } from "lucide-react" +import FileManager from "./FileManager" +import { getMediaFile } from "@/lib/api" +import { useStore } from "@/lib/states" +import SettingsDialog from "./Settings" +import { cn, fileToImage } from "@/lib/utils" +import Coffee from "./Coffee" +import { useToast } from "./ui/use-toast" + +const Header = () => { + const [ + file, + customMask, + isInpainting, + serverConfig, + runMannually, + enableUploadMask, + model, + setFile, + setCustomFile, + runInpainting, + showPrevMask, + hidePrevMask, + imageHeight, + imageWidth, + ] = useStore((state) => [ + state.file, + state.customMask, + state.isInpainting, + state.serverConfig, + state.runMannually(), + state.settings.enableUploadMask, + state.settings.model, + state.setFile, + state.setCustomFile, + state.runInpainting, + state.showPrevMask, + state.hidePrevMask, + state.imageHeight, + state.imageWidth, + ]) + + const { toast } = useToast() + const [maskImage, maskImageLoaded] = useImage(customMask) + const [openMaskPopover, setOpenMaskPopover] = useState(false) + + const handleRerunLastMask = () => { + runInpainting() + } + + const onRerunMouseEnter = () => { + showPrevMask() + } + + const onRerunMouseLeave = () => { + hidePrevMask() + } + + return ( +
+
+ {serverConfig.enableFileManager ? ( + { + try { + const newFile = await getMediaFile(tab, filename) + setFile(newFile) + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + return + } + }} + /> + ) : ( + <> + )} + + { + setFile(file) + }} + > + + + +
+ { + let newCustomMask: HTMLImageElement | null = null + try { + newCustomMask = await fileToImage(file) + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + return + } + if ( + newCustomMask.naturalHeight !== imageHeight || + newCustomMask.naturalWidth !== imageWidth + ) { + toast({ + variant: "destructive", + description: `The size of the mask must same as image: ${imageWidth}x${imageHeight}`, + }) + return + } + + setCustomFile(file) + if (!runMannually) { + runInpainting() + } + }} + > + + + + {customMask ? ( + + setOpenMaskPopover(true)} + onMouseLeave={() => setOpenMaskPopover(false)} + style={{ + visibility: customMask ? "visible" : "hidden", + outline: "none", + }} + onClick={() => { + if (customMask) { + } + }} + > + + + + + + {maskImageLoaded ? ( + Custom mask + ) : ( + <> + )} + + + ) : ( + <> + )} +
+ + {file && !model.need_prompt ? ( + + + + ) : ( + <> + )} +
+ + {model.need_prompt ? : <>} + +
+ + + {serverConfig.disableModelSwitch ? <> : } +
+
+ ) +} + +export default Header diff --git a/custom-demo/back-end/web_app/src/components/ImageSize.tsx b/custom-demo/back-end/web_app/src/components/ImageSize.tsx new file mode 100644 index 0000000..17314d4 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/ImageSize.tsx @@ -0,0 +1,20 @@ +import { useStore } from "@/lib/states" + +const ImageSize = () => { + const [imageWidth, imageHeight] = useStore((state) => [ + state.imageWidth, + state.imageHeight, + ]) + + if (!imageWidth || !imageHeight) { + return null + } + + return ( +
+ {imageWidth}x{imageHeight} +
+ ) +} + +export default ImageSize diff --git a/custom-demo/back-end/web_app/src/components/InteractiveSeg.tsx b/custom-demo/back-end/web_app/src/components/InteractiveSeg.tsx new file mode 100644 index 0000000..d540708 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/InteractiveSeg.tsx @@ -0,0 +1,130 @@ +import { useStore } from "@/lib/states" +import { Button } from "./ui/button" +import { Dialog, DialogContent, DialogTitle } from "./ui/dialog" + +interface InteractiveSegReplaceModal { + show: boolean + onClose: () => void + onCleanClick: () => void + onReplaceClick: () => void +} + +const InteractiveSegReplaceModal = (props: InteractiveSegReplaceModal) => { + const { show, onClose, onCleanClick, onReplaceClick } = props + + const onOpenChange = (open: boolean) => { + if (!open) { + onClose() + } + } + + return ( + + + Do you want to remove it or create a new one? +
+ + +
+
+
+ ) +} + +const InteractiveSegConfirmActions = () => { + const [ + interactiveSegState, + resetInteractiveSegState, + handleInteractiveSegAccept, + ] = useStore((state) => [ + state.interactiveSegState, + state.resetInteractiveSegState, + state.handleInteractiveSegAccept, + ]) + + if (!interactiveSegState.isInteractiveSeg) { + return null + } + + return ( +
+ + +
+ ) +} + +interface ItemProps { + x: number + y: number + positive: boolean +} + +const Item = (props: ItemProps) => { + const { x, y, positive } = props + const name = positive + ? "bg-[rgba(21,_215,_121,_0.936)] outline-[rgba(98,255,179,0.31)]" + : "bg-[rgba(237,_49,_55,_0.942)] outline-[rgba(255,89,95,0.31)]" + return ( +
+ ) +} + +const InteractiveSegPoints = () => { + const clicks = useStore((state) => state.interactiveSegState.clicks) + + return ( +
+ {clicks.map((click) => { + return ( + + ) + })} +
+ ) +} + +const InteractiveSeg = () => { + return ( +
+ + {/* */} +
+ ) +} + +export { InteractiveSeg, InteractiveSegPoints } diff --git a/custom-demo/back-end/web_app/src/components/Plugins.tsx b/custom-demo/back-end/web_app/src/components/Plugins.tsx new file mode 100644 index 0000000..5c70984 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/Plugins.tsx @@ -0,0 +1,202 @@ +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from "./ui/dropdown-menu" +import { Button } from "./ui/button" +import { + Blocks, + Fullscreen, + MousePointerClick, + Slice, + Smile, +} from "lucide-react" +import { useStore } from "@/lib/states" +import { PluginInfo } from "@/lib/types" + +export enum PluginName { + RemoveBG = "RemoveBG", + AnimeSeg = "AnimeSeg", + RealESRGAN = "RealESRGAN", + GFPGAN = "GFPGAN", + RestoreFormer = "RestoreFormer", + InteractiveSeg = "InteractiveSeg", +} + +// TODO: get plugin config from server and using form-render?? +const pluginMap = { + [PluginName.RemoveBG]: { + IconClass: Slice, + showName: "RemoveBG", + }, + [PluginName.AnimeSeg]: { + IconClass: Slice, + showName: "Anime Segmentation", + }, + [PluginName.RealESRGAN]: { + IconClass: Fullscreen, + showName: "RealESRGAN", + }, + [PluginName.GFPGAN]: { + IconClass: Smile, + showName: "GFPGAN", + }, + [PluginName.RestoreFormer]: { + IconClass: Smile, + showName: "RestoreFormer", + }, + [PluginName.InteractiveSeg]: { + IconClass: MousePointerClick, + showName: "Interactive Segmentation", + }, +} + +const Plugins = () => { + const [ + file, + plugins, + isPluginRunning, + updateInteractiveSegState, + runRenderablePlugin, + ] = useStore((state) => [ + state.file, + state.serverConfig.plugins, + state.isPluginRunning, + state.updateInteractiveSegState, + state.runRenderablePlugin, + ]) + const disabled = !file + + if (plugins.length === 0) { + return null + } + + const onPluginClick = (genMask: boolean, pluginName: string) => { + if (pluginName === PluginName.InteractiveSeg) { + updateInteractiveSegState({ isInteractiveSeg: true }) + } else { + runRenderablePlugin(genMask, pluginName) + } + } + + const renderRealESRGANPlugin = () => { + return ( + + +
+ + RealESRGAN +
+
+ + + runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 2 }) + } + > + upscale 2x + + + runRenderablePlugin(false, PluginName.RealESRGAN, { upscale: 4 }) + } + > + upscale 4x + + +
+ ) + } + + const renderGenImageAndMaskPlugin = (plugin: PluginInfo) => { + const { IconClass, showName } = pluginMap[plugin.name as PluginName] + return ( + + +
+ + {showName} +
+
+ + onPluginClick(false, plugin.name)}> + Remove Background + + onPluginClick(true, plugin.name)}> + Generate Mask + + +
+ ) + } + + const renderPlugins = () => { + return plugins.map((plugin: PluginInfo) => { + const { IconClass, showName } = pluginMap[plugin.name as PluginName] + if (plugin.name === PluginName.RealESRGAN) { + return renderRealESRGANPlugin() + } + if ( + plugin.name === PluginName.RemoveBG || + plugin.name === PluginName.AnimeSeg + ) { + return renderGenImageAndMaskPlugin(plugin) + } + return ( + onPluginClick(false, plugin.name)} + disabled={disabled} + > +
+ + {showName} +
+
+ ) + }) + } + + return ( + + + + + + {renderPlugins()} + + + ) +} + +export default Plugins diff --git a/custom-demo/back-end/web_app/src/components/PromptInput.tsx b/custom-demo/back-end/web_app/src/components/PromptInput.tsx new file mode 100644 index 0000000..abd0891 --- /dev/null +++ b/custom-demo/back-end/web_app/src/components/PromptInput.tsx @@ -0,0 +1,94 @@ +import React, { FormEvent, useRef } from "react" +import { Button } from "./ui/button" +import { useStore } from "@/lib/states" +import { useClickAway, useToggle } from "react-use" +import { Textarea } from "./ui/textarea" +import { cn } from "@/lib/utils" + +const PromptInput = () => { + const [ + isProcessing, + prompt, + updateSettings, + runInpainting, + showPrevMask, + hidePrevMask, + ] = useStore((state) => [ + state.getIsProcessing(), + state.settings.prompt, + state.updateSettings, + state.runInpainting, + state.showPrevMask, + state.hidePrevMask, + ]) + + const [showScroll, toggleShowScroll] = useToggle(false) + + const ref = useRef(null) + useClickAway(ref, () => { + if (ref?.current) { + const input = ref.current as HTMLTextAreaElement + input.blur() + } + }) + + const handleOnInput = (evt: FormEvent) => { + evt.preventDefault() + evt.stopPropagation() + const target = evt.target as HTMLTextAreaElement + updateSettings({ prompt: target.value }) + } + + const handleRepaintClick = () => { + if (!isProcessing) { + runInpainting() + } + } + + const onKeyUp = (e: React.KeyboardEvent) => { + if (e.key === "Enter" && e.ctrlKey && prompt.length !== 0) { + handleRepaintClick() + } + } + + const onMouseEnter = () => { + showPrevMask() + } + + const onMouseLeave = () => { + hidePrevMask() + } + + return ( +
+
+