add remove bg model selection
This commit is contained in:
		
							parent
							
								
									cf9ceea4e6
								
							
						
					
					
						commit
						8060e16c70
					
				| 
						 | 
				
			
			@ -42,10 +42,10 @@ from iopaint.helper import (
 | 
			
		|||
    adjust_mask,
 | 
			
		||||
)
 | 
			
		||||
from iopaint.model.utils import torch_gc
 | 
			
		||||
from iopaint.model_info import ModelInfo
 | 
			
		||||
from iopaint.model_manager import ModelManager
 | 
			
		||||
from iopaint.plugins import build_plugins
 | 
			
		||||
from iopaint.plugins.base_plugin import BasePlugin
 | 
			
		||||
from iopaint.plugins.remove_bg import RemoveBG
 | 
			
		||||
from iopaint.schema import (
 | 
			
		||||
    GenInfoResponse,
 | 
			
		||||
    ApiConfig,
 | 
			
		||||
| 
						 | 
				
			
			@ -56,6 +56,9 @@ from iopaint.schema import (
 | 
			
		|||
    SDSampler,
 | 
			
		||||
    PluginInfo,
 | 
			
		||||
    AdjustMaskRequest,
 | 
			
		||||
    RemoveBGModel,
 | 
			
		||||
    SwitchPluginModelRequest,
 | 
			
		||||
    ModelInfo,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
 | 
			
		||||
| 
						 | 
				
			
			@ -154,11 +157,11 @@ class Api:
 | 
			
		|||
        # fmt: off
 | 
			
		||||
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
 | 
			
		||||
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
 | 
			
		||||
        self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
 | 
			
		||||
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
 | 
			
		||||
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
 | 
			
		||||
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
 | 
			
		||||
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
 | 
			
		||||
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
 | 
			
		||||
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
 | 
			
		||||
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
 | 
			
		||||
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
 | 
			
		||||
| 
						 | 
				
			
			@ -175,9 +178,6 @@ class Api:
 | 
			
		|||
    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register ``endpoint`` under ``path`` on the wrapped application.

        All keyword arguments are forwarded unchanged to
        ``self.app.add_api_route`` and its return value is passed back.
        """
        return self.app.add_api_route(path, endpoint, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def api_models(self) -> List[ModelInfo]:
        """Return whatever ``self.model_manager.scan_models()`` reports."""
        return self.model_manager.scan_models()
 | 
			
		||||
 | 
			
		||||
    def api_current_model(self) -> ModelInfo:
        """Return the model manager's ``current_model``."""
        return self.model_manager.current_model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -187,16 +187,28 @@ class Api:
 | 
			
		|||
        self.model_manager.switch(req.name)
 | 
			
		||||
        return self.model_manager.current_model
 | 
			
		||||
 | 
			
		||||
    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        """Switch the model used by an already-loaded plugin.

        A request naming a plugin that is not present in ``self.plugins`` is
        ignored and the handler returns ``None``. When the plugin is the
        RemoveBG plugin, the chosen model name is also written to
        ``self.config.remove_bg_model``.
        """
        if req.plugin_name not in self.plugins:
            return

        self.plugins[req.plugin_name].switch_model(req.model_name)
        if req.plugin_name == RemoveBG.name:
            self.config.remove_bg_model = req.model_name
 | 
			
		||||
 | 
			
		||||
    def api_server_config(self) -> ServerConfigResponse:
 | 
			
		||||
        return ServerConfigResponse(
 | 
			
		||||
            plugins=[
 | 
			
		||||
        plugins = []
 | 
			
		||||
        for it in self.plugins.values():
 | 
			
		||||
            plugins.append(
 | 
			
		||||
                PluginInfo(
 | 
			
		||||
                    name=it.name,
 | 
			
		||||
                    support_gen_image=it.support_gen_image,
 | 
			
		||||
                    support_gen_mask=it.support_gen_mask,
 | 
			
		||||
                )
 | 
			
		||||
                for it in self.plugins.values()
 | 
			
		||||
            ],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return ServerConfigResponse(
 | 
			
		||||
            plugins=plugins,
 | 
			
		||||
            modelInfos=self.model_manager.scan_models(),
 | 
			
		||||
            removeBGModel=self.config.remove_bg_model,
 | 
			
		||||
            removeBGModels=RemoveBGModel.values(),
 | 
			
		||||
            enableFileManager=self.file_manager is not None,
 | 
			
		||||
            enableAutoSaving=self.config.output_dir is not None,
 | 
			
		||||
            enableControlnet=self.model_manager.enable_controlnet,
 | 
			
		||||
| 
						 | 
				
			
			@ -340,6 +352,7 @@ class Api:
 | 
			
		|||
            self.config.interactive_seg_model,
 | 
			
		||||
            self.config.interactive_seg_device,
 | 
			
		||||
            self.config.enable_remove_bg,
 | 
			
		||||
            self.config.remove_bg_model,
 | 
			
		||||
            self.config.enable_anime_seg,
 | 
			
		||||
            self.config.enable_realesrgan,
 | 
			
		||||
            self.config.realesrgan_device,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -9,7 +9,7 @@ from typer_config import use_json_config
 | 
			
		|||
 | 
			
		||||
from iopaint.const import *
 | 
			
		||||
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
 | 
			
		||||
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel
 | 
			
		||||
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel
 | 
			
		||||
 | 
			
		||||
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -127,6 +127,7 @@ def start(
 | 
			
		|||
    ),
 | 
			
		||||
    interactive_seg_device: Device = Option(Device.cpu),
 | 
			
		||||
    enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
 | 
			
		||||
    remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
 | 
			
		||||
    enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
 | 
			
		||||
    enable_realesrgan: bool = Option(False),
 | 
			
		||||
    realesrgan_device: Device = Option(Device.cpu),
 | 
			
		||||
| 
						 | 
				
			
			@ -183,6 +184,7 @@ def start(
 | 
			
		|||
        interactive_seg_model=interactive_seg_model,
 | 
			
		||||
        interactive_seg_device=interactive_seg_device,
 | 
			
		||||
        enable_remove_bg=enable_remove_bg,
 | 
			
		||||
        remove_bg_model=remove_bg_model,
 | 
			
		||||
        enable_anime_seg=enable_anime_seg,
 | 
			
		||||
        enable_realesrgan=enable_realesrgan,
 | 
			
		||||
        realesrgan_device=realesrgan_device,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,8 +1,5 @@
 | 
			
		|||
import json
 | 
			
		||||
import os
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from iopaint.schema import ApiConfig, Device, InteractiveSegModel, RealESRGANModel
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
 | 
			
		||||
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
 | 
			
		||||
| 
						 | 
				
			
			@ -57,7 +54,7 @@ CPU_TEXTENCODER_HELP = """
 | 
			
		|||
Run diffusion models text encoder on CPU to reduce vRAM usage.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
SD_CONTROLNET_CHOICES = [
 | 
			
		||||
SD_CONTROLNET_CHOICES: List[str] = [
 | 
			
		||||
    "lllyasviel/control_v11p_sd15_canny",
 | 
			
		||||
    # "lllyasviel/control_v11p_sd15_seg",
 | 
			
		||||
    "lllyasviel/control_v11p_sd15_openpose",
 | 
			
		||||
| 
						 | 
				
			
			@ -113,38 +110,9 @@ Quality of image encoding, 0-100. Default is 95, higher quality will generate la
 | 
			
		|||
 | 
			
		||||
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
 | 
			
		||||
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
 | 
			
		||||
REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
 | 
			
		||||
ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
 | 
			
		||||
REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU"
 | 
			
		||||
ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
 | 
			
		||||
REALESRGAN_HELP = "Enable realesrgan super resolution"
 | 
			
		||||
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
 | 
			
		||||
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
 | 
			
		||||
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
 | 
			
		||||
 | 
			
		||||
default_configs = dict(
 | 
			
		||||
    host="127.0.0.1",
 | 
			
		||||
    port=8080,
 | 
			
		||||
    model=DEFAULT_MODEL,
 | 
			
		||||
    model_dir=DEFAULT_MODEL_DIR,
 | 
			
		||||
    no_half=False,
 | 
			
		||||
    low_mem=False,
 | 
			
		||||
    cpu_offload=False,
 | 
			
		||||
    disable_nsfw_checker=False,
 | 
			
		||||
    local_files_only=False,
 | 
			
		||||
    cpu_textencoder=False,
 | 
			
		||||
    device=Device.cuda,
 | 
			
		||||
    input=None,
 | 
			
		||||
    output_dir=None,
 | 
			
		||||
    quality=95,
 | 
			
		||||
    enable_interactive_seg=False,
 | 
			
		||||
    interactive_seg_model=InteractiveSegModel.vit_b,
 | 
			
		||||
    interactive_seg_device=Device.cpu,
 | 
			
		||||
    enable_remove_bg=False,
 | 
			
		||||
    enable_anime_seg=False,
 | 
			
		||||
    enable_realesrgan=False,
 | 
			
		||||
    realesrgan_device=Device.cpu,
 | 
			
		||||
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
 | 
			
		||||
    enable_gfpgan=False,
 | 
			
		||||
    gfpgan_device=Device.cpu,
 | 
			
		||||
    enable_restoreformer=False,
 | 
			
		||||
    restoreformer_device=Device.cpu,
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ import os
 | 
			
		|||
from functools import lru_cache
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
from iopaint.schema import ModelType, ModelInfo
 | 
			
		||||
from loguru import logger
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -15,7 +16,6 @@ from iopaint.const import (
 | 
			
		|||
    ANYTEXT_NAME,
 | 
			
		||||
)
 | 
			
		||||
from iopaint.model.original_sd_configs import get_config_files
 | 
			
		||||
from iopaint.model_info import ModelInfo, ModelType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cli_download_model(model: str):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,103 +0,0 @@
 | 
			
		|||
from typing import List
 | 
			
		||||
 | 
			
		||||
from pydantic import computed_field, BaseModel
 | 
			
		||||
 | 
			
		||||
from iopaint.const import (
 | 
			
		||||
    SDXL_CONTROLNET_CHOICES,
 | 
			
		||||
    SD2_CONTROLNET_CHOICES,
 | 
			
		||||
    SD_CONTROLNET_CHOICES,
 | 
			
		||||
    INSTRUCT_PIX2PIX_NAME,
 | 
			
		||||
    KANDINSKY22_NAME,
 | 
			
		||||
    POWERPAINT_NAME,
 | 
			
		||||
    ANYTEXT_NAME,
 | 
			
		||||
)
 | 
			
		||||
from iopaint.schema import ModelType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelInfo(BaseModel):
    # Describes one model (name, path, type) together with capability flags
    # derived from ``model_type`` and ``name``.
    #
    # NOTE: documentation here is written as comments rather than docstrings;
    # pydantic copies docstrings into the generated JSON schema descriptions,
    # so docstrings would change the emitted schema.
    name: str
    path: str
    model_type: ModelType
    is_single_file_diffusers: bool = False

    @computed_field
    @property
    def need_prompt(self) -> bool:
        # True for the four diffusers SD/SDXL model types and for a fixed set
        # of models identified by name.
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [
            INSTRUCT_PIX2PIX_NAME,
            KANDINSKY22_NAME,
            POWERPAINT_NAME,
            ANYTEXT_NAME,
        ]

    @computed_field
    @property
    def controlnets(self) -> List[str]:
        # ControlNet choices for this model; an empty list when none apply.
        if self.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            return SDXL_CONTROLNET_CHOICES
        if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
            # SD models are split purely on whether the name contains "sd2".
            if "sd2" in self.name.lower():
                return SD2_CONTROLNET_CHOICES
            else:
                return SD_CONTROLNET_CHOICES
        if self.name == POWERPAINT_NAME:
            return SD_CONTROLNET_CHOICES
        return []

    @computed_field
    @property
    def support_strength(self) -> bool:
        # Diffusers SD/SDXL types, plus PowerPaint and AnyText by name.
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]

    @computed_field
    @property
    def support_outpainting(self) -> bool:
        # Diffusers SD/SDXL types, plus Kandinsky 2.2 and PowerPaint by name.
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]

    @computed_field
    @property
    def support_lcm_lora(self) -> bool:
        # Only the four diffusers SD/SDXL model types.
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_controlnet(self) -> bool:
        # Only the four diffusers SD/SDXL model types.
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_freeu(self) -> bool:
        # Diffusers SD/SDXL types, plus InstructPix2Pix by name.
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [INSTRUCT_PIX2PIX_NAME]
 | 
			
		||||
| 
						 | 
				
			
			@ -8,8 +8,7 @@ from iopaint.download import scan_models
 | 
			
		|||
from iopaint.helper import switch_mps_device
 | 
			
		||||
from iopaint.model import models, ControlNet, SD, SDXL
 | 
			
		||||
from iopaint.model.utils import torch_gc, is_local_files_only
 | 
			
		||||
from iopaint.model_info import ModelInfo, ModelType
 | 
			
		||||
from iopaint.schema import InpaintRequest
 | 
			
		||||
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelManager:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,6 +16,7 @@ def build_plugins(
 | 
			
		|||
    interactive_seg_model: InteractiveSegModel,
 | 
			
		||||
    interactive_seg_device: Device,
 | 
			
		||||
    enable_remove_bg: bool,
 | 
			
		||||
    remove_bg_model: str,
 | 
			
		||||
    enable_anime_seg: bool,
 | 
			
		||||
    enable_realesrgan: bool,
 | 
			
		||||
    realesrgan_device: Device,
 | 
			
		||||
| 
						 | 
				
			
			@ -35,7 +36,7 @@ def build_plugins(
 | 
			
		|||
 | 
			
		||||
    if enable_remove_bg:
 | 
			
		||||
        logger.info(f"Initialize {RemoveBG.name} plugin")
 | 
			
		||||
        plugins[RemoveBG.name] = RemoveBG()
 | 
			
		||||
        plugins[RemoveBG.name] = RemoveBG(remove_bg_model)
 | 
			
		||||
 | 
			
		||||
    if enable_anime_seg:
 | 
			
		||||
        logger.info(f"Initialize {AnimeSeg.name} plugin")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,3 +25,6 @@ class BasePlugin:
 | 
			
		|||
 | 
			
		||||
    def check_dep(self):
        """Hook with an empty body here; it does nothing and returns ``None``."""
        ...
 | 
			
		||||
 | 
			
		||||
    def switch_model(self, new_model_name: str):
        """Hook for changing the plugin's model; this base version is a no-op.

        Args:
            new_model_name: Name of the model to switch to (ignored here).
        """
        ...
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,515 @@
 | 
			
		|||
# copy from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py
 | 
			
		||||
import cv2
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import numpy as np
 | 
			
		||||
from torchvision.transforms.functional import normalize
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dirate: Dilation rate of the convolution. The same value is used as
            padding, so with ``stride=1`` the spatial size is preserved.
        stride: Stride of the convolution.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        # Attribute names are kept as-is: they are the state_dict keys.
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
 | 
			
		||||
def _upsample_like(src, tar):
 | 
			
		||||
    src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
 | 
			
		||||
 | 
			
		||||
    return src
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### RSU-7 ###
 | 
			
		||||
class RSU7(nn.Module):
    """Residual U-shaped block with seven convolution levels.

    The encoder applies ``rebnconv1`` .. ``rebnconv6`` with a 2x ceil-mode
    max-pool after each of the first five; ``rebnconv7`` is a dilated
    (``dirate=2``) convolution at the lowest resolution. Each decoder stage
    concatenates the previous decoder output with the matching encoder
    feature along dim 1 and is upsampled to the next encoder feature's size.
    The decoder output is added to the input projection ``rebnconvin(x)``.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the U structure.
        out_ch: Channels of the output tensor.
        img_size: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs take 2 * mid_ch channels: decoder feature + skip feature.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the encoder/decoder and return ``decoder_output + rebnconvin(x)``."""
        hxin = self.rebnconvin(x)

        # Encoder: keep every pre-pool feature for the skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder: concatenate with the skip feature, then upsample to the
        # spatial size of the next (higher-resolution) encoder feature.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### RSU-6 ###
 | 
			
		||||
class RSU6(nn.Module):
    """Residual U-shaped block with six convolution levels.

    Same layout as a U-shaped encoder/decoder: ``rebnconv1`` .. ``rebnconv5``
    with a 2x ceil-mode max-pool after each of the first four, a dilated
    (``dirate=2``) ``rebnconv6`` at the lowest resolution, and decoder stages
    that concatenate with the matching encoder feature along dim 1. The
    decoder output is added to ``rebnconvin(x)``.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the U structure.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs take 2 * mid_ch channels: decoder feature + skip feature.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the encoder/decoder and return ``decoder_output + rebnconvin(x)``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### RSU-5 ###
 | 
			
		||||
class RSU5(nn.Module):
    """Residual U-shaped block with five convolution levels.

    ``rebnconv1`` .. ``rebnconv4`` form the encoder, with a 2x ceil-mode
    max-pool after each of the first three; ``rebnconv5`` is a dilated
    (``dirate=2``) convolution at the lowest resolution. Decoder stages
    concatenate with the matching encoder feature along dim 1, and the
    decoder output is added to ``rebnconvin(x)``.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the U structure.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs take 2 * mid_ch channels: decoder feature + skip feature.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the encoder/decoder and return ``decoder_output + rebnconvin(x)``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### RSU-4 ###
 | 
			
		||||
class RSU4(nn.Module):
    """Residual U-shaped block with four convolution levels.

    ``rebnconv1`` .. ``rebnconv3`` form the encoder, with a 2x ceil-mode
    max-pool after each of the first two; ``rebnconv4`` is a dilated
    (``dirate=2``) convolution at the lowest resolution. Decoder stages
    concatenate with the matching encoder feature along dim 1, and the
    decoder output is added to ``rebnconvin(x)``.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the U structure.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder convs take 2 * mid_ch channels: decoder feature + skip feature.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the encoder/decoder and return ``decoder_output + rebnconvin(x)``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### RSU-4F ###
 | 
			
		||||
class RSU4F(nn.Module):
    """Four-level residual block that uses dilation instead of pooling.

    No pooling or upsampling is performed: the encoder convolutions use
    dilation rates 1, 2, 4 and 8, and the decoder convolutions use 4, 2 and 1.
    Each decoder stage concatenates the previous output with the matching
    encoder feature along dim 1, and the decoder output is added to
    ``rebnconvin(x)``.

    Args:
        in_ch: Channels of the input tensor.
        mid_ch: Channels used inside the block.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder convs take 2 * mid_ch channels: decoder feature + skip feature.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the dilated encoder/decoder and return ``decoder_output + rebnconvin(x)``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class myrebnconv(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU.

    All constructor arguments except ``in_ch``/``out_ch`` are forwarded
    unchanged to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super().__init__()

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        # Attribute names double as state_dict key prefixes; keep them stable.
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.rl(normalized)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BriaRMBG(nn.Module):
    """Six-stage encoder / five-stage decoder network built from RSU blocks.

    ``forward`` returns a pair ``(probabilities, features)``: six sigmoid
    maps (one per side head, each resized to match the input via
    ``_upsample_like``) and the six feature maps those heads were applied to.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # NOTE: attribute names and registration order are kept stable so
        # existing state_dict checkpoints keep loading.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Registered but not used by forward().
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # encoder
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder: every stage takes the upsampled deeper output concatenated
        # with the encoder output of the same level (hence doubled channels).
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # side heads: one 3x3 conv per decoder level producing out_ch maps
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        stem = self.conv_in(x)

        # -------------------- encoder --------------------
        # Each stage output is kept for its skip connection.
        hx1 = self.stage1(stem)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((_upsample_like(hx6, hx5), hx5), 1))
        hx4d = self.stage4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.stage3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.stage2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.stage1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        # -------------------- side outputs --------------------
        features = [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        heads = (
            self.side1,
            self.side2,
            self.side3,
            self.side4,
            self.side5,
            self.side6,
        )
        probabilities = [
            F.sigmoid(_upsample_like(head(feature), x))
            for head, feature in zip(heads, features)
        ]

        return probabilities, features
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def resize_image(image, model_input_size=(1024, 1024)):
    """Return *image* converted to RGB mode and bilinearly resized.

    Args:
        image: PIL image to prepare.
        model_input_size: Size tuple handed to ``Image.resize``. Defaults to
            ``(1024, 1024)``, the value that used to be hard-coded here.

    Returns:
        A new PIL image in ``"RGB"`` mode with the requested size.
    """
    rgb_image = image.convert("RGB")
    return rgb_image.resize(model_input_size, Image.BILINEAR)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_briarmbg_session():
    """Build a BriaRMBG network loaded with the ``briaai/RMBG-1.4`` weights.

    The checkpoint file ``model.pth`` is fetched through
    ``huggingface_hub.hf_hub_download``.

    Returns:
        The network in eval mode, placed on CUDA when a CUDA device is
        available and on the CPU otherwise.
    """
    # Function-scope import: huggingface_hub is only imported when a session
    # is actually created.
    from huggingface_hub import hf_hub_download

    net = BriaRMBG()
    model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
    # NOTE(security): torch.load unpickles the downloaded file. Consider
    # passing weights_only=True once the minimum supported torch allows it.
    state_dict = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state_dict)

    # Inference inputs are moved to CUDA whenever it is available, so the
    # weights must live there too or the forward pass fails with a device
    # mismatch.
    if torch.cuda.is_available():
        net = net.cuda()

    net.eval()
    return net
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def briarmbg_process(bgr_np_image, session, only_mask=False):
    """Run the BriaRMBG network on an image and return a mask or a cut-out.

    Args:
        bgr_np_image: Image array accepted by ``Image.fromarray``
            (presumably HxWx3 uint8 in BGR order, as the name suggests —
            confirm against callers).
        session: Network to run; called as ``session(tensor)`` and expected
            to return the structure produced by ``BriaRMBG.forward``.
        only_mask: When true, return only the foreground mask.

    Returns:
        With ``only_mask`` a uint8 mask at the input's height/width, scaled
        to 0..255. Otherwise an RGBA array of the same size whose colour
        channels are copied from the input unchanged and whose pixels are
        pasted through that mask.
    """
    # prepare input
    # NOTE(review): the array is handed to PIL as-is, so the network sees the
    # channels in the caller's order — confirm the model tolerates BGR input.
    orig_bgr_image = Image.fromarray(bgr_np_image)
    w, h = orig_bgr_image.size
    image = resize_image(orig_bgr_image)
    im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # Follow the device of the weights instead of forcing CUDA: a CPU-resident
    # model fed a CUDA tensor raises a device-mismatch error.
    first_param = next(session.parameters(), None)
    if first_param is not None:
        im_tensor = im_tensor.to(first_param.device)

    # inference + post process; no autograd graph is needed here
    with torch.no_grad():
        result = session(im_tensor)
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
        )
        ma = torch.max(result)
        mi = torch.min(result)
        span = ma - mi
        # A perfectly flat prediction would divide by zero and turn the mask
        # into NaNs; keep the raw sigmoid values in that case.
        if span > 0:
            result = (result - mi) / span

    # image to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    mask = np.squeeze(im_array)
    if only_mask:
        return mask

    pil_im = Image.fromarray(mask)
    # paste the original image onto a transparent canvas through the mask
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_bgr_image, mask=pil_im)
    return np.asarray(new_im)
 | 
			
		||||
| 
						 | 
				
			
			@ -1,8 +1,6 @@
 | 
			
		|||
import hashlib
 | 
			
		||||
import json
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
import cv2
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from loguru import logger
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,10 +1,11 @@
 | 
			
		|||
import os
 | 
			
		||||
import cv2
 | 
			
		||||
import numpy as np
 | 
			
		||||
from loguru import logger
 | 
			
		||||
from torch.hub import get_dir
 | 
			
		||||
 | 
			
		||||
from iopaint.plugins.base_plugin import BasePlugin
 | 
			
		||||
from iopaint.schema import RunPluginRequest
 | 
			
		||||
from iopaint.schema import RunPluginRequest, RemoveBGModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RemoveBG(BasePlugin):
 | 
			
		||||
| 
						 | 
				
			
			@ -12,32 +13,53 @@ class RemoveBG(BasePlugin):
 | 
			
		|||
    support_gen_mask = True
 | 
			
		||||
    support_gen_image = True
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
    def __init__(self, model_name):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        from rembg import new_session
 | 
			
		||||
        self.model_name = model_name
 | 
			
		||||
 | 
			
		||||
        hub_dir = get_dir()
 | 
			
		||||
        model_dir = os.path.join(hub_dir, "checkpoints")
 | 
			
		||||
        os.environ["U2NET_HOME"] = model_dir
 | 
			
		||||
 | 
			
		||||
        self.session = new_session(model_name="u2net")
 | 
			
		||||
        self._init_session(model_name)
 | 
			
		||||
 | 
			
		||||
    def _init_session(self, model_name: str):
        """Create the inference session for *model_name*.

        Sets ``self.session`` and ``self.remove`` (the callable later invoked
        as ``self.remove(image, session=self.session, ...)``). The
        briaai/RMBG-1.4 model uses the in-tree implementation; every other
        name is passed to ``rembg.new_session``.
        """
        wants_briarmbg = model_name == RemoveBGModel.briaai_rmbg_1_4

        if not wants_briarmbg:
            # Backends are imported lazily, only for the branch that needs them.
            from rembg import new_session, remove

            self.session = new_session(model_name=model_name)
            self.remove = remove
            return

        from iopaint.plugins.briarmbg import (
            briarmbg_process,
            create_briarmbg_session,
        )

        self.session = create_briarmbg_session()
        self.remove = briarmbg_process
 | 
			
		||||
 | 
			
		||||
    def switch_model(self, new_model_name):
        """Switch to *new_model_name*; does nothing if it is already active."""
        if self.model_name != new_model_name:
            logger.info(
                f"Switching removebg model from {self.model_name} to {new_model_name}"
            )
            self._init_session(new_model_name)
            # Updated last, so a failed session init leaves the old name in place.
            self.model_name = new_model_name
 | 
			
		||||
 | 
			
		||||
    def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
 | 
			
		||||
        from rembg import remove
 | 
			
		||||
 | 
			
		||||
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
 | 
			
		||||
 | 
			
		||||
        # return BGRA image
 | 
			
		||||
        output = remove(bgr_np_img, session=self.session)
 | 
			
		||||
        output = self.remove(bgr_np_img, session=self.session)
 | 
			
		||||
        return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
 | 
			
		||||
 | 
			
		||||
    def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
 | 
			
		||||
        from rembg import remove
 | 
			
		||||
 | 
			
		||||
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
 | 
			
		||||
 | 
			
		||||
        # return BGR image, 255 means foreground, 0 means background
 | 
			
		||||
        output = remove(bgr_np_img, session=self.session, only_mask=True)
 | 
			
		||||
        output = self.remove(bgr_np_img, session=self.session, only_mask=True)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def check_dep(self):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,11 +5,11 @@ import sys
 | 
			
		|||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import packaging.version
 | 
			
		||||
from iopaint.schema import Device
 | 
			
		||||
from loguru import logger
 | 
			
		||||
from rich import print
 | 
			
		||||
from typing import Dict, Any
 | 
			
		||||
 | 
			
		||||
from iopaint.const import Device
 | 
			
		||||
 | 
			
		||||
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,11 +1,117 @@
 | 
			
		|||
import json
 | 
			
		||||
import random
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Optional, Literal, List
 | 
			
		||||
 | 
			
		||||
from iopaint.const import (
 | 
			
		||||
    INSTRUCT_PIX2PIX_NAME,
 | 
			
		||||
    KANDINSKY22_NAME,
 | 
			
		||||
    POWERPAINT_NAME,
 | 
			
		||||
    ANYTEXT_NAME,
 | 
			
		||||
    SDXL_CONTROLNET_CHOICES,
 | 
			
		||||
    SD2_CONTROLNET_CHOICES,
 | 
			
		||||
    SD_CONTROLNET_CHOICES,
 | 
			
		||||
)
 | 
			
		||||
from loguru import logger
 | 
			
		||||
from pydantic import BaseModel, Field, field_validator
 | 
			
		||||
from pydantic import BaseModel, Field, field_validator, computed_field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelType(str, Enum):
    # Model categories. Mixing in ``str`` makes every member compare equal to
    # its plain string value.
    INPAINT = "inpaint"  # LaMa, MAT...
    DIFFUSERS_SD = "diffusers_sd"
    DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
    DIFFUSERS_SDXL = "diffusers_sdxl"
    DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
    DIFFUSERS_OTHER = "diffusers_other"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelInfo(BaseModel):
    # Description of one model plus capability flags derived from its type
    # and name. Only ``#`` comments are used here on purpose: pydantic copies
    # docstrings into the generated JSON schema.
    name: str
    path: str
    model_type: ModelType
    is_single_file_diffusers: bool = False

    def _is_sd_or_sdxl(self) -> bool:
        # True for the four diffusers SD / SDXL types (plain and inpaint).
        return self.model_type in (
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        )

    @computed_field
    @property
    def need_prompt(self) -> bool:
        if self._is_sd_or_sdxl():
            return True
        return self.name in (
            INSTRUCT_PIX2PIX_NAME,
            KANDINSKY22_NAME,
            POWERPAINT_NAME,
            ANYTEXT_NAME,
        )

    @computed_field
    @property
    def controlnets(self) -> List[str]:
        sdxl_types = (ModelType.DIFFUSERS_SDXL, ModelType.DIFFUSERS_SDXL_INPAINT)
        sd_types = (ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT)

        if self.model_type in sdxl_types:
            return SDXL_CONTROLNET_CHOICES
        if self.model_type in sd_types:
            # SD2 checkpoints are recognised by "sd2" appearing in the name.
            is_sd2 = "sd2" in self.name.lower()
            return SD2_CONTROLNET_CHOICES if is_sd2 else SD_CONTROLNET_CHOICES
        if self.name == POWERPAINT_NAME:
            return SD_CONTROLNET_CHOICES
        return []

    @computed_field
    @property
    def support_strength(self) -> bool:
        if self._is_sd_or_sdxl():
            return True
        return self.name in (POWERPAINT_NAME, ANYTEXT_NAME)

    @computed_field
    @property
    def support_outpainting(self) -> bool:
        if self._is_sd_or_sdxl():
            return True
        return self.name in (KANDINSKY22_NAME, POWERPAINT_NAME)

    @computed_field
    @property
    def support_lcm_lora(self) -> bool:
        return self._is_sd_or_sdxl()

    @computed_field
    @property
    def support_controlnet(self) -> bool:
        return self._is_sd_or_sdxl()

    @computed_field
    @property
    def support_freeu(self) -> bool:
        if self._is_sd_or_sdxl():
            return True
        return self.name in (INSTRUCT_PIX2PIX_NAME,)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Choices(str, Enum):
 | 
			
		||||
| 
						 | 
				
			
			@ -20,6 +126,16 @@ class RealESRGANModel(Choices):
 | 
			
		|||
    RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RemoveBGModel(Choices):
    # Model names selectable for background removal. The member values are
    # the strings compared and forwarded by RemoveBG._init_session.
    u2net = "u2net"
    u2netp = "u2netp"
    u2net_human_seg = "u2net_human_seg"
    u2net_cloth_seg = "u2net_cloth_seg"
    silueta = "silueta"
    isnet_general_use = "isnet-general-use"
    # Handled by the in-tree BriaRMBG implementation instead of rembg.
    briaai_rmbg_1_4 = "briaai/RMBG-1.4"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Device(Choices):
 | 
			
		||||
    cpu = "cpu"
 | 
			
		||||
    cuda = "cuda"
 | 
			
		||||
| 
						 | 
				
			
			@ -44,15 +160,6 @@ class CV2Flag(str, Enum):
 | 
			
		|||
    INPAINT_TELEA = "INPAINT_TELEA"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelType(str, Enum):
 | 
			
		||||
    INPAINT = "inpaint"  # LaMa, MAT...
 | 
			
		||||
    DIFFUSERS_SD = "diffusers_sd"
 | 
			
		||||
    DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
 | 
			
		||||
    DIFFUSERS_SDXL = "diffusers_sdxl"
 | 
			
		||||
    DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
 | 
			
		||||
    DIFFUSERS_OTHER = "diffusers_other"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HDStrategy(str, Enum):
 | 
			
		||||
    # Use original image size
 | 
			
		||||
    ORIGINAL = "Original"
 | 
			
		||||
| 
						 | 
				
			
			@ -124,6 +231,7 @@ class ApiConfig(BaseModel):
 | 
			
		|||
    interactive_seg_model: InteractiveSegModel
 | 
			
		||||
    interactive_seg_device: Device
 | 
			
		||||
    enable_remove_bg: bool
 | 
			
		||||
    remove_bg_model: str
 | 
			
		||||
    enable_anime_seg: bool
 | 
			
		||||
    enable_realesrgan: bool
 | 
			
		||||
    realesrgan_device: Device
 | 
			
		||||
| 
						 | 
				
			
			@ -313,6 +421,9 @@ class GenInfoResponse(BaseModel):
 | 
			
		|||
 | 
			
		||||
class ServerConfigResponse(BaseModel):
 | 
			
		||||
    plugins: List[PluginInfo]
 | 
			
		||||
    modelInfos: List[ModelInfo]
 | 
			
		||||
    removeBGModel: RemoveBGModel
 | 
			
		||||
    removeBGModels: List[str]
 | 
			
		||||
    enableFileManager: bool
 | 
			
		||||
    enableAutoSaving: bool
 | 
			
		||||
    enableControlnet: bool
 | 
			
		||||
| 
						 | 
				
			
			@ -326,6 +437,11 @@ class SwitchModelRequest(BaseModel):
 | 
			
		|||
    name: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SwitchPluginModelRequest(BaseModel):
    # Request payload naming a plugin and the model it should switch to.
    plugin_name: str
    model_name: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
AdjustMaskOperate = Literal["expand", "shrink", "reverse"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,7 +5,7 @@ from PIL import Image
 | 
			
		|||
 | 
			
		||||
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
 | 
			
		||||
from iopaint.plugins.anime_seg import AnimeSeg
 | 
			
		||||
from iopaint.schema import RunPluginRequest
 | 
			
		||||
from iopaint.schema import RunPluginRequest, RemoveBGModel
 | 
			
		||||
from iopaint.tests.utils import check_device, current_dir, save_dir
 | 
			
		||||
 | 
			
		||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,7 +34,7 @@ def _save(img, name):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_remove_bg():
 | 
			
		||||
    model = RemoveBG()
 | 
			
		||||
    model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4)
 | 
			
		||||
    rgba_np_img = model.gen_image(
 | 
			
		||||
        rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,14 @@
 | 
			
		|||
import json
 | 
			
		||||
import os
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from iopaint.schema import (
 | 
			
		||||
    Device,
 | 
			
		||||
    InteractiveSegModel,
 | 
			
		||||
    RemoveBGModel,
 | 
			
		||||
    RealESRGANModel,
 | 
			
		||||
    ApiConfig,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -15,6 +25,37 @@ from iopaint.const import *
 | 
			
		|||
_config_file: Path = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Default value for each configuration option, keyed by option name.
# NOTE(review): presumably the fallback when no saved config file exists —
# confirm against the code that loads the config.
default_configs = dict(
    host="127.0.0.1",
    port=8080,
    model=DEFAULT_MODEL,
    model_dir=DEFAULT_MODEL_DIR,
    no_half=False,
    low_mem=False,
    cpu_offload=False,
    disable_nsfw_checker=False,
    local_files_only=False,
    cpu_textencoder=False,
    device=Device.cuda,
    input=None,
    output_dir=None,
    quality=95,
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.vit_b,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    remove_bg_model=RemoveBGModel.u2net,
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WebConfig(ApiConfig):
 | 
			
		||||
    model_dir: str = DEFAULT_MODEL_DIR
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -50,6 +91,7 @@ def save_config(
 | 
			
		|||
    interactive_seg_model,
 | 
			
		||||
    interactive_seg_device,
 | 
			
		||||
    enable_remove_bg,
 | 
			
		||||
    remove_bg_model,
 | 
			
		||||
    enable_anime_seg,
 | 
			
		||||
    enable_realesrgan,
 | 
			
		||||
    realesrgan_device,
 | 
			
		||||
| 
						 | 
				
			
			@ -115,7 +157,7 @@ def main(config_file: Path):
 | 
			
		|||
                with gr.Row():
 | 
			
		||||
                    recommend_model = gr.Dropdown(
 | 
			
		||||
                        ["lama", "mat", "migan"] + DIFFUSION_MODELS,
 | 
			
		||||
                        label="Recommend Models",
 | 
			
		||||
                        label="Recommended Models",
 | 
			
		||||
                    )
 | 
			
		||||
                    downloaded_model = gr.Dropdown(
 | 
			
		||||
                        downloaded_models, label="Downloaded Models"
 | 
			
		||||
| 
						 | 
				
			
			@ -179,6 +221,11 @@ def main(config_file: Path):
 | 
			
		|||
                    enable_remove_bg = gr.Checkbox(
 | 
			
		||||
                        init_config.enable_remove_bg, label=REMOVE_BG_HELP
 | 
			
		||||
                    )
 | 
			
		||||
                    remove_bg_model = gr.Radio(
 | 
			
		||||
                        RemoveBGModel.values(),
 | 
			
		||||
                        label="Remove bg model",
 | 
			
		||||
                        value=init_config.remove_bg_model,
 | 
			
		||||
                    )
 | 
			
		||||
                with gr.Row():
 | 
			
		||||
                    enable_anime_seg = gr.Checkbox(
 | 
			
		||||
                        init_config.enable_anime_seg, label=ANIMESEG_HELP
 | 
			
		||||
| 
						 | 
				
			
			@ -241,6 +288,7 @@ def main(config_file: Path):
 | 
			
		|||
                interactive_seg_model,
 | 
			
		||||
                interactive_seg_device,
 | 
			
		||||
                enable_remove_bg,
 | 
			
		||||
                remove_bg_model,
 | 
			
		||||
                enable_anime_seg,
 | 
			
		||||
                enable_realesrgan,
 | 
			
		||||
                realesrgan_device,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
 | 
			
		|||
import { useEffect, useState } from "react"
 | 
			
		||||
import { cn } from "@/lib/utils"
 | 
			
		||||
import { useQuery } from "@tanstack/react-query"
 | 
			
		||||
import { fetchModelInfos, switchModel } from "@/lib/api"
 | 
			
		||||
import { ModelInfo } from "@/lib/types"
 | 
			
		||||
import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api"
 | 
			
		||||
import { ModelInfo, PluginName } from "@/lib/types"
 | 
			
		||||
import { useStore } from "@/lib/states"
 | 
			
		||||
import { ScrollArea } from "./ui/scroll-area"
 | 
			
		||||
import { useToast } from "./ui/use-toast"
 | 
			
		||||
| 
						 | 
				
			
			@ -39,6 +39,14 @@ import {
 | 
			
		|||
  MODEL_TYPE_OTHER,
 | 
			
		||||
} from "@/lib/const"
 | 
			
		||||
import useHotKey from "@/hooks/useHotkey"
 | 
			
		||||
import {
 | 
			
		||||
  Select,
 | 
			
		||||
  SelectContent,
 | 
			
		||||
  SelectGroup,
 | 
			
		||||
  SelectItem,
 | 
			
		||||
  SelectTrigger,
 | 
			
		||||
  SelectValue,
 | 
			
		||||
} from "./ui/select"
 | 
			
		||||
 | 
			
		||||
const formSchema = z.object({
 | 
			
		||||
  enableFileManager: z.boolean(),
 | 
			
		||||
| 
						 | 
				
			
			@ -48,42 +56,45 @@ const formSchema = z.object({
 | 
			
		|||
  enableManualInpainting: z.boolean(),
 | 
			
		||||
  enableUploadMask: z.boolean(),
 | 
			
		||||
  enableAutoExtractPrompt: z.boolean(),
 | 
			
		||||
  removeBGModel: z.string(),
 | 
			
		||||
})
 | 
			
		||||
 | 
			
		||||
const TAB_GENERAL = "General"
 | 
			
		||||
const TAB_MODEL = "Model"
 | 
			
		||||
const TAB_PLUGINS = "Plugins"
 | 
			
		||||
// const TAB_FILE_MANAGER = "File Manager"
 | 
			
		||||
 | 
			
		||||
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
 | 
			
		||||
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS]
 | 
			
		||||
 | 
			
		||||
export function SettingsDialog() {
 | 
			
		||||
  const [open, toggleOpen] = useToggle(false)
 | 
			
		||||
  const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
 | 
			
		||||
  const [tab, setTab] = useState(TAB_MODEL)
 | 
			
		||||
  const [
 | 
			
		||||
    updateAppState,
 | 
			
		||||
    settings,
 | 
			
		||||
    updateSettings,
 | 
			
		||||
    fileManagerState,
 | 
			
		||||
    updateFileManagerState,
 | 
			
		||||
    setAppModel,
 | 
			
		||||
    setServerConfig,
 | 
			
		||||
  ] = useStore((state) => [
 | 
			
		||||
    state.updateAppState,
 | 
			
		||||
    state.settings,
 | 
			
		||||
    state.updateSettings,
 | 
			
		||||
    state.fileManagerState,
 | 
			
		||||
    state.updateFileManagerState,
 | 
			
		||||
    state.setModel,
 | 
			
		||||
    state.setServerConfig,
 | 
			
		||||
  ])
 | 
			
		||||
  const { toast } = useToast()
 | 
			
		||||
  const [model, setModel] = useState<ModelInfo>(settings.model)
 | 
			
		||||
  const [modelSwitchingTexts, setModelSwitchingTexts] = useState<string[]>([])
 | 
			
		||||
  const openModelSwitching = modelSwitchingTexts.length > 0
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    setModel(settings.model)
 | 
			
		||||
  }, [settings.model])
 | 
			
		||||
 | 
			
		||||
  const { data: modelInfos, status } = useQuery({
 | 
			
		||||
    queryKey: ["modelInfos"],
 | 
			
		||||
    queryFn: fetchModelInfos,
 | 
			
		||||
  const { data: serverConfig, status } = useQuery({
 | 
			
		||||
    queryKey: ["serverConfig"],
 | 
			
		||||
    queryFn: getServerConfig,
 | 
			
		||||
  })
 | 
			
		||||
 | 
			
		||||
  // 1. Define your form.
 | 
			
		||||
| 
						 | 
				
			
			@ -96,9 +107,17 @@ export function SettingsDialog() {
 | 
			
		|||
      enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
 | 
			
		||||
      inputDirectory: fileManagerState.inputDirectory,
 | 
			
		||||
      outputDirectory: fileManagerState.outputDirectory,
 | 
			
		||||
      removeBGModel: serverConfig?.removeBGModel,
 | 
			
		||||
    },
 | 
			
		||||
  })
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    if (serverConfig) {
 | 
			
		||||
      setServerConfig(serverConfig)
 | 
			
		||||
      form.setValue("removeBGModel", serverConfig.removeBGModel)
 | 
			
		||||
    }
 | 
			
		||||
  }, [form, serverConfig])
 | 
			
		||||
 | 
			
		||||
  async function onSubmit(values: z.infer<typeof formSchema>) {
 | 
			
		||||
    // Do something with the form values. ✅ This will be type-safe and validated.
 | 
			
		||||
    updateSettings({
 | 
			
		||||
| 
						 | 
				
			
			@ -109,29 +128,67 @@ export function SettingsDialog() {
 | 
			
		|||
    })
 | 
			
		||||
 | 
			
		||||
    // TODO: validate input/output Directory
 | 
			
		||||
    updateFileManagerState({
 | 
			
		||||
      inputDirectory: values.inputDirectory,
 | 
			
		||||
      outputDirectory: values.outputDirectory,
 | 
			
		||||
    })
 | 
			
		||||
    if (model.name !== settings.model.name) {
 | 
			
		||||
      toggleOpenModelSwitching()
 | 
			
		||||
      updateAppState({ disableShortCuts: true })
 | 
			
		||||
      try {
 | 
			
		||||
        const newModel = await switchModel(model.name)
 | 
			
		||||
        toast({
 | 
			
		||||
          title: `Switch to ${newModel.name} success`,
 | 
			
		||||
        })
 | 
			
		||||
        setAppModel(model)
 | 
			
		||||
      } catch (error: any) {
 | 
			
		||||
        toast({
 | 
			
		||||
          variant: "destructive",
 | 
			
		||||
          title: `Switch to ${model.name} failed: ${error}`,
 | 
			
		||||
        })
 | 
			
		||||
        setModel(settings.model)
 | 
			
		||||
      } finally {
 | 
			
		||||
        toggleOpenModelSwitching()
 | 
			
		||||
        updateAppState({ disableShortCuts: false })
 | 
			
		||||
    // updateFileManagerState({
 | 
			
		||||
    //   inputDirectory: values.inputDirectory,
 | 
			
		||||
    //   outputDirectory: values.outputDirectory,
 | 
			
		||||
    // })
 | 
			
		||||
 | 
			
		||||
    const shouldSwitchModel = model.name !== settings.model.name
 | 
			
		||||
    const shouldSwitchRemoveBGModel =
 | 
			
		||||
      serverConfig?.removeBGModel !== values.removeBGModel
 | 
			
		||||
    const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel
 | 
			
		||||
 | 
			
		||||
    if (showModelSwitching) {
 | 
			
		||||
      const newModelSwitchingTexts: string[] = []
 | 
			
		||||
      if (shouldSwitchModel) {
 | 
			
		||||
        newModelSwitchingTexts.push(
 | 
			
		||||
          `Switching model from ${settings.model.name} to ${model.name}`
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      if (shouldSwitchRemoveBGModel) {
 | 
			
		||||
        newModelSwitchingTexts.push(
 | 
			
		||||
          `Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      setModelSwitchingTexts(newModelSwitchingTexts)
 | 
			
		||||
 | 
			
		||||
      updateAppState({ disableShortCuts: true })
 | 
			
		||||
 | 
			
		||||
      if (shouldSwitchModel) {
 | 
			
		||||
        try {
 | 
			
		||||
          const newModel = await switchModel(model.name)
 | 
			
		||||
          toast({
 | 
			
		||||
            title: `Switch to ${newModel.name} success`,
 | 
			
		||||
          })
 | 
			
		||||
          setAppModel(model)
 | 
			
		||||
        } catch (error: any) {
 | 
			
		||||
          toast({
 | 
			
		||||
            variant: "destructive",
 | 
			
		||||
            title: `Switch to ${model.name} failed: ${error}`,
 | 
			
		||||
          })
 | 
			
		||||
          setModel(settings.model)
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (shouldSwitchRemoveBGModel) {
 | 
			
		||||
        try {
 | 
			
		||||
          const res = await switchPluginModel(
 | 
			
		||||
            PluginName.RemoveBG,
 | 
			
		||||
            values.removeBGModel
 | 
			
		||||
          )
 | 
			
		||||
          if (res.status !== 200) {
 | 
			
		||||
            throw new Error(res.statusText)
 | 
			
		||||
          }
 | 
			
		||||
        } catch (error: any) {
 | 
			
		||||
          toast({
 | 
			
		||||
            variant: "destructive",
 | 
			
		||||
            title: `Switch removebg model to ${model.name} failed: ${error}`,
 | 
			
		||||
          })
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      setModelSwitchingTexts([])
 | 
			
		||||
      updateAppState({ disableShortCuts: false })
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -143,7 +200,17 @@ export function SettingsDialog() {
 | 
			
		|||
        onSubmit(form.getValues())
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    [open, form, model]
 | 
			
		||||
    [open, form, model, serverConfig]
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
  if (status !== "success") {
 | 
			
		||||
    return <></>
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const modelInfos = serverConfig.modelInfos
 | 
			
		||||
  const plugins = serverConfig.plugins
 | 
			
		||||
  const removeBGEnabled = plugins.some(
 | 
			
		||||
    (plugin) => plugin.name === PluginName.RemoveBG
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
  function onOpenChange(value: boolean) {
 | 
			
		||||
| 
						 | 
				
			
			@ -186,10 +253,6 @@ export function SettingsDialog() {
 | 
			
		|||
  }
 | 
			
		||||
 | 
			
		||||
  function renderModelSettings() {
 | 
			
		||||
    if (status !== "success") {
 | 
			
		||||
      return <></>
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let defaultTab = MODEL_TYPE_INPAINT
 | 
			
		||||
    for (let info of modelInfos) {
 | 
			
		||||
      if (model.name === info.name) {
 | 
			
		||||
| 
						 | 
				
			
			@ -356,6 +419,44 @@ export function SettingsDialog() {
 | 
			
		|||
    )
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  function renderPluginsSettings() {
 | 
			
		||||
    return (
 | 
			
		||||
      <div className="space-y-4 w-[510px]">
 | 
			
		||||
        <FormField
 | 
			
		||||
          control={form.control}
 | 
			
		||||
          name="removeBGModel"
 | 
			
		||||
          render={({ field }) => (
 | 
			
		||||
            <FormItem className="flex items-center justify-between">
 | 
			
		||||
              <div className="space-y-0.5">
 | 
			
		||||
                <FormLabel>Remove Background</FormLabel>
 | 
			
		||||
                <FormDescription>Remove background model</FormDescription>
 | 
			
		||||
              </div>
 | 
			
		||||
              <Select
 | 
			
		||||
                onValueChange={field.onChange}
 | 
			
		||||
                defaultValue={field.value}
 | 
			
		||||
                disabled={!removeBGEnabled}
 | 
			
		||||
              >
 | 
			
		||||
                <FormControl>
 | 
			
		||||
                  <SelectTrigger className="w-[200px]">
 | 
			
		||||
                    <SelectValue placeholder="Select removebg model" />
 | 
			
		||||
                  </SelectTrigger>
 | 
			
		||||
                </FormControl>
 | 
			
		||||
                <SelectContent align="end">
 | 
			
		||||
                  <SelectGroup>
 | 
			
		||||
                    {serverConfig?.removeBGModels.map((model) => (
 | 
			
		||||
                      <SelectItem key={model} value={model}>
 | 
			
		||||
                        {model}
 | 
			
		||||
                      </SelectItem>
 | 
			
		||||
                    ))}
 | 
			
		||||
                  </SelectGroup>
 | 
			
		||||
                </SelectContent>
 | 
			
		||||
              </Select>
 | 
			
		||||
            </FormItem>
 | 
			
		||||
          )}
 | 
			
		||||
        />
 | 
			
		||||
      </div>
 | 
			
		||||
    )
 | 
			
		||||
  }
 | 
			
		||||
  // function renderFileManagerSettings() {
 | 
			
		||||
  //   return (
 | 
			
		||||
  //     <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
 | 
			
		||||
| 
						 | 
				
			
			@ -446,7 +547,9 @@ export function SettingsDialog() {
 | 
			
		|||
                <span className="sr-only">Loading...</span>
 | 
			
		||||
              </div>
 | 
			
		||||
 | 
			
		||||
              <div>Switching to {model.name}</div>
 | 
			
		||||
              {modelSwitchingTexts.map((text, index) => (
 | 
			
		||||
                <div key={index}>{text}</div>
 | 
			
		||||
              ))}
 | 
			
		||||
            </div>
 | 
			
		||||
            {/* </AlertDialogDescription> */}
 | 
			
		||||
          </AlertDialogHeader>
 | 
			
		||||
| 
						 | 
				
			
			@ -473,6 +576,7 @@ export function SettingsDialog() {
 | 
			
		|||
                <Button
 | 
			
		||||
                  key={item}
 | 
			
		||||
                  variant="ghost"
 | 
			
		||||
                  disabled={item === TAB_PLUGINS && !removeBGEnabled}
 | 
			
		||||
                  onClick={() => setTab(item)}
 | 
			
		||||
                  className={cn(
 | 
			
		||||
                    tab === item ? "bg-muted " : "hover:bg-muted",
 | 
			
		||||
| 
						 | 
				
			
			@ -489,6 +593,7 @@ export function SettingsDialog() {
 | 
			
		|||
                <form onSubmit={form.handleSubmit(onSubmit)}>
 | 
			
		||||
                  {tab === TAB_MODEL ? renderModelSettings() : <></>}
 | 
			
		||||
                  {tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
 | 
			
		||||
                  {tab === TAB_PLUGINS ? renderPluginsSettings() : <></>}
 | 
			
		||||
                  {/* {tab === TAB_FILE_MANAGER ? (
 | 
			
		||||
                    renderFileManagerSettings()
 | 
			
		||||
                  ) : (
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,7 +12,7 @@ export default function useInputImage() {
 | 
			
		|||
    fetch(`${API_ENDPOINT}/inputimage`, { headers })
 | 
			
		||||
      .then(async (res) => {
 | 
			
		||||
        if (!res.ok) {
 | 
			
		||||
          throw new Error("No input image found")
 | 
			
		||||
          return
 | 
			
		||||
        }
 | 
			
		||||
        const filename = res.headers
 | 
			
		||||
          .get("content-disposition")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -104,15 +104,18 @@ export async function switchModel(name: string): Promise<ModelInfo> {
 | 
			
		|||
  return res.data
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function switchPluginModel(
 | 
			
		||||
  plugin_name: string,
 | 
			
		||||
  model_name: string
 | 
			
		||||
) {
 | 
			
		||||
  return api.post(`/switch_plugin_model`, { plugin_name, model_name })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function currentModel(): Promise<ModelInfo> {
 | 
			
		||||
  const res = await api.get("/model")
 | 
			
		||||
  return res.data
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function fetchModelInfos(): Promise<ModelInfo[]> {
 | 
			
		||||
  return api.get("/models").then((response) => response.data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function runPlugin(
 | 
			
		||||
  genMask: boolean,
 | 
			
		||||
  name: string,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,6 +14,9 @@ export interface PluginInfo {
 | 
			
		|||
 | 
			
		||||
export interface ServerConfig {
 | 
			
		||||
  plugins: PluginInfo[]
 | 
			
		||||
  modelInfos: ModelInfo[]
 | 
			
		||||
  removeBGModel: string
 | 
			
		||||
  removeBGModels: string[]
 | 
			
		||||
  enableFileManager: boolean
 | 
			
		||||
  enableAutoSaving: boolean
 | 
			
		||||
  enableControlnet: boolean
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue