switch controlnet in webui
This commit is contained in:
		
							parent
							
								
									0363472adc
								
							
						
					
					
						commit
						3eef8f4dae
					
				| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
import { PluginName } from '../components/Plugins/Plugins'
 | 
			
		||||
import { Rect, Settings } from '../store/Atoms'
 | 
			
		||||
import { ControlNetMethodMap, Rect, Settings } from '../store/Atoms'
 | 
			
		||||
import { dataURItoBlob, loadImage, srcToFile } from '../utils'
 | 
			
		||||
 | 
			
		||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
 | 
			
		||||
| 
						 | 
				
			
			@ -92,6 +92,10 @@ export default async function inpaint(
 | 
			
		|||
    'controlnet_conditioning_scale',
 | 
			
		||||
    settings.controlnetConditioningScale.toString()
 | 
			
		||||
  )
 | 
			
		||||
  fd.append(
 | 
			
		||||
    'controlnet_method',
 | 
			
		||||
    ControlNetMethodMap[settings.controlnetMethod.toString()]
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
  try {
 | 
			
		||||
    const res = await fetch(`${API_ENDPOINT}/inpaint`, {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil'
 | 
			
		|||
import * as PopoverPrimitive from '@radix-ui/react-popover'
 | 
			
		||||
import { useToggle } from 'react-use'
 | 
			
		||||
import {
 | 
			
		||||
  ControlNetMethod,
 | 
			
		||||
  isControlNetState,
 | 
			
		||||
  isInpaintingState,
 | 
			
		||||
  negativePropmtState,
 | 
			
		||||
| 
						 | 
				
			
			@ -47,6 +48,44 @@ const SidePanel = () => {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const renderConterNetSetting = () => {
 | 
			
		||||
    return (
 | 
			
		||||
      <>
 | 
			
		||||
        <SettingBlock
 | 
			
		||||
          className="sub-setting-block"
 | 
			
		||||
          title="ControlNet"
 | 
			
		||||
          input={
 | 
			
		||||
            <Selector
 | 
			
		||||
              width={80}
 | 
			
		||||
              value={setting.controlnetMethod as string}
 | 
			
		||||
              options={Object.values(ControlNetMethod)}
 | 
			
		||||
              onChange={val => {
 | 
			
		||||
                const method = val as ControlNetMethod
 | 
			
		||||
                setSettingState(old => {
 | 
			
		||||
                  return { ...old, controlnetMethod: method }
 | 
			
		||||
                })
 | 
			
		||||
              }}
 | 
			
		||||
            />
 | 
			
		||||
          }
 | 
			
		||||
        />
 | 
			
		||||
 | 
			
		||||
        <NumberInputSetting
 | 
			
		||||
          title="ControlNet Weight"
 | 
			
		||||
          width={INPUT_WIDTH}
 | 
			
		||||
          allowFloat
 | 
			
		||||
          value={`${setting.controlnetConditioningScale}`}
 | 
			
		||||
          desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
 | 
			
		||||
          onValue={value => {
 | 
			
		||||
            const val = value.length === 0 ? 0 : parseFloat(value)
 | 
			
		||||
            setSettingState(old => {
 | 
			
		||||
              return { ...old, controlnetConditioningScale: val }
 | 
			
		||||
            })
 | 
			
		||||
          }}
 | 
			
		||||
        />
 | 
			
		||||
      </>
 | 
			
		||||
    )
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <div className="side-panel">
 | 
			
		||||
      <PopoverPrimitive.Root open={open}>
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +97,8 @@ const SidePanel = () => {
 | 
			
		|||
        </PopoverPrimitive.Trigger>
 | 
			
		||||
        <PopoverPrimitive.Portal>
 | 
			
		||||
          <PopoverPrimitive.Content className="side-panel-content">
 | 
			
		||||
            {isControlNet && renderConterNetSetting()}
 | 
			
		||||
 | 
			
		||||
            <SettingBlock
 | 
			
		||||
              title="Croper"
 | 
			
		||||
              input={
 | 
			
		||||
| 
						 | 
				
			
			@ -117,22 +158,6 @@ const SidePanel = () => {
 | 
			
		|||
              }}
 | 
			
		||||
            />
 | 
			
		||||
 | 
			
		||||
            {isControlNet && (
 | 
			
		||||
              <NumberInputSetting
 | 
			
		||||
                title="ControlNet Weight"
 | 
			
		||||
                width={INPUT_WIDTH}
 | 
			
		||||
                allowFloat
 | 
			
		||||
                value={`${setting.controlnetConditioningScale}`}
 | 
			
		||||
                desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
 | 
			
		||||
                onValue={value => {
 | 
			
		||||
                  const val = value.length === 0 ? 0 : parseFloat(value)
 | 
			
		||||
                  setSettingState(old => {
 | 
			
		||||
                    return { ...old, controlnetConditioningScale: val }
 | 
			
		||||
                  })
 | 
			
		||||
                }}
 | 
			
		||||
              />
 | 
			
		||||
            )}
 | 
			
		||||
 | 
			
		||||
            <NumberInputSetting
 | 
			
		||||
              title="Mask Blur"
 | 
			
		||||
              width={INPUT_WIDTH}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -440,6 +440,7 @@ export interface Settings {
 | 
			
		|||
 | 
			
		||||
  // ControlNet
 | 
			
		||||
  controlnetConditioningScale: number
 | 
			
		||||
  controlnetMethod: string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const defaultHDSettings: ModelsHDSettings = {
 | 
			
		||||
| 
						 | 
				
			
			@ -546,6 +547,18 @@ export enum SDSampler {
 | 
			
		|||
  uni_pc = 'uni_pc',
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export enum ControlNetMethod {
 | 
			
		||||
  canny = 'canny',
 | 
			
		||||
  inpaint = 'inpaint',
 | 
			
		||||
  openpose = 'openpose',
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export const ControlNetMethodMap: any = {
 | 
			
		||||
  canny: 'control_v11p_sd15_canny',
 | 
			
		||||
  inpaint: 'control_v11p_sd15_inpaint',
 | 
			
		||||
  openpose: 'control_v11p_sd15_openpose',
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export enum SDMode {
 | 
			
		||||
  text2img = 'text2img',
 | 
			
		||||
  img2img = 'img2img',
 | 
			
		||||
| 
						 | 
				
			
			@ -597,7 +610,8 @@ export const settingStateDefault: Settings = {
 | 
			
		|||
  p2pGuidanceScale: 7.5,
 | 
			
		||||
 | 
			
		||||
  // ControlNet
 | 
			
		||||
  controlnetConditioningScale: 0.4,
 | 
			
		||||
  controlnetConditioningScale: 1.0,
 | 
			
		||||
  controlnetMethod: ControlNetMethod.canny,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const localStorageEffect =
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -55,6 +55,7 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
 | 
			
		|||
SD_CONTROLNET_HELP = """
 | 
			
		||||
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
 | 
			
		||||
"""
 | 
			
		||||
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
 | 
			
		||||
SD_CONTROLNET_CHOICES = [
 | 
			
		||||
    "control_v11p_sd15_canny",
 | 
			
		||||
    "control_v11p_sd15_openpose",
 | 
			
		||||
| 
						 | 
				
			
			@ -133,6 +134,7 @@ class Config(BaseModel):
 | 
			
		|||
    model: str = DEFAULT_MODEL
 | 
			
		||||
    sd_local_model_path: str = None
 | 
			
		||||
    sd_controlnet: bool = False
 | 
			
		||||
    sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
 | 
			
		||||
    device: str = DEFAULT_DEVICE
 | 
			
		||||
    gui: bool = False
 | 
			
		||||
    no_gui_auto_close: bool = False
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,7 +6,12 @@ import torch
 | 
			
		|||
import numpy as np
 | 
			
		||||
from loguru import logger
 | 
			
		||||
 | 
			
		||||
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device
 | 
			
		||||
from lama_cleaner.helper import (
 | 
			
		||||
    boxes_from_mask,
 | 
			
		||||
    resize_max_size,
 | 
			
		||||
    pad_img_to_modulo,
 | 
			
		||||
    switch_mps_device,
 | 
			
		||||
)
 | 
			
		||||
from lama_cleaner.schema import Config, HDStrategy
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -199,7 +204,9 @@ class InpaintModel:
 | 
			
		|||
 | 
			
		||||
            # only calculate histograms for non-masked parts
 | 
			
		||||
            source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
 | 
			
		||||
            reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])
 | 
			
		||||
            reference_histogram, _ = np.histogram(
 | 
			
		||||
                reference_channel[mask == 0], 256, [0, 256]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            source_cdf = self._calculate_cdf(source_histogram)
 | 
			
		||||
            reference_cdf = self._calculate_cdf(reference_histogram)
 | 
			
		||||
| 
						 | 
				
			
			@ -273,9 +280,10 @@ class DiffusionInpaintModel(InpaintModel):
 | 
			
		|||
        origin_size = image.shape[:2]
 | 
			
		||||
        downsize_image = resize_max_size(image, size_limit=longer_side_length)
 | 
			
		||||
        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
 | 
			
		||||
        logger.info(
 | 
			
		||||
            f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
 | 
			
		||||
        )
 | 
			
		||||
        if config.sd_scale != 1:
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
 | 
			
		||||
            )
 | 
			
		||||
        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
 | 
			
		||||
        # only paste masked area result
 | 
			
		||||
        inpaint_result = cv2.resize(
 | 
			
		||||
| 
						 | 
				
			
			@ -284,5 +292,7 @@ class DiffusionInpaintModel(InpaintModel):
 | 
			
		|||
            interpolation=cv2.INTER_CUBIC,
 | 
			
		||||
        )
 | 
			
		||||
        original_pixel_indices = mask < 127
 | 
			
		||||
        inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
 | 
			
		||||
        inpaint_result[original_pixel_indices] = image[:, :, ::-1][
 | 
			
		||||
            original_pixel_indices
 | 
			
		||||
        ]
 | 
			
		||||
        return inpaint_result
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -203,6 +203,7 @@ class ControlNet(DiffusionInpaintModel):
 | 
			
		|||
                negative_prompt=config.negative_prompt,
 | 
			
		||||
                generator=torch.manual_seed(config.sd_seed),
 | 
			
		||||
                output_type="np.array",
 | 
			
		||||
                callback=self.callback
 | 
			
		||||
            ).images[0]
 | 
			
		||||
        else:
 | 
			
		||||
            if "canny" in self.sd_controlnet_method:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -455,7 +455,7 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
 | 
			
		|||
        timesteps = self.scheduler.timesteps
 | 
			
		||||
 | 
			
		||||
        # 6. Prepare latent variables
 | 
			
		||||
        num_channels_latents = self.controlnet.in_channels
 | 
			
		||||
        num_channels_latents = self.controlnet.config.in_channels
 | 
			
		||||
        latents = self.prepare_latents(
 | 
			
		||||
            batch_size * num_images_per_prompt,
 | 
			
		||||
            num_channels_latents,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,8 @@
 | 
			
		|||
import torch
 | 
			
		||||
import gc
 | 
			
		||||
 | 
			
		||||
from loguru import logger
 | 
			
		||||
 | 
			
		||||
from lama_cleaner.const import SD15_MODELS
 | 
			
		||||
from lama_cleaner.helper import switch_mps_device
 | 
			
		||||
from lama_cleaner.model.controlnet import ControlNet
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +60,7 @@ class ModelManager:
 | 
			
		|||
            raise NotImplementedError(f"Not supported model: {name}")
 | 
			
		||||
 | 
			
		||||
    def __call__(self, image, mask, config: Config):
 | 
			
		||||
        self.switch_controlnet_method(control_method=config.controlnet_method)
 | 
			
		||||
        return self.model(image, mask, config)
 | 
			
		||||
 | 
			
		||||
    def switch(self, new_name: str, **kwargs):
 | 
			
		||||
| 
						 | 
				
			
			@ -86,7 +89,9 @@ class ModelManager:
 | 
			
		|||
        del self.model
 | 
			
		||||
        torch_gc()
 | 
			
		||||
 | 
			
		||||
        old_method = self.kwargs["sd_controlnet_method"]
 | 
			
		||||
        self.kwargs["sd_controlnet_method"] = control_method
 | 
			
		||||
        self.model = self.init_model(
 | 
			
		||||
            self.name, switch_mps_device(self.name, self.device), **self.kwargs
 | 
			
		||||
        )
 | 
			
		||||
        logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -40,7 +40,7 @@ def parse_args():
 | 
			
		|||
    parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--sd-controlnet-method",
 | 
			
		||||
        default="control_v11p_sd15_inpaint",
 | 
			
		||||
        default=DEFAULT_CONTROLNET_METHOD,
 | 
			
		||||
        choices=SD_CONTROLNET_CHOICES,
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,3 +4,4 @@ from .realesrgan import RealESRGANUpscaler
 | 
			
		|||
from .gfpgan_plugin import GFPGANPlugin
 | 
			
		||||
from .restoreformer import RestoreFormerPlugin
 | 
			
		||||
from .gif import MakeGIF
 | 
			
		||||
from .anime_seg import AnimeSeg
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -96,4 +96,5 @@ class Config(BaseModel):
 | 
			
		|||
    p2p_guidance_scale: float = 7.5
 | 
			
		||||
 | 
			
		||||
    # ControlNet
 | 
			
		||||
    controlnet_conditioning_scale: float = 0.4
 | 
			
		||||
    controlnet_conditioning_scale: float = 1.0
 | 
			
		||||
    controlnet_method: str = "control_v11p_sd15_canny"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,6 @@
 | 
			
		|||
#!/usr/bin/env python3
 | 
			
		||||
import asyncio
 | 
			
		||||
import hashlib
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
from lama_cleaner.plugins.anime_seg import AnimeSeg
 | 
			
		||||
import hashlib
 | 
			
		||||
 | 
			
		||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -32,6 +29,7 @@ from lama_cleaner.plugins import (
 | 
			
		|||
    MakeGIF,
 | 
			
		||||
    GFPGANPlugin,
 | 
			
		||||
    RestoreFormerPlugin,
 | 
			
		||||
    AnimeSeg,
 | 
			
		||||
)
 | 
			
		||||
from lama_cleaner.schema import Config
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -84,7 +82,15 @@ BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
 | 
			
		|||
 | 
			
		||||
class NoFlaskwebgui(logging.Filter):
 | 
			
		||||
    def filter(self, record):
 | 
			
		||||
        return "flaskwebgui-keep-server-alive" not in record.getMessage()
 | 
			
		||||
        msg = record.getMessage()
 | 
			
		||||
        if "Running on http:" in msg:
 | 
			
		||||
            print(msg[msg.index("Running on http:") :])
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            "flaskwebgui-keep-server-alive" not in msg
 | 
			
		||||
            and "socket.io" not in msg
 | 
			
		||||
            and "This is a development server." not in msg
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
 | 
			
		||||
| 
						 | 
				
			
			@ -92,6 +98,9 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
 | 
			
		|||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
 | 
			
		||||
app.config["JSON_AS_ASCII"] = False
 | 
			
		||||
CORS(app, expose_headers=["Content-Disposition"])
 | 
			
		||||
 | 
			
		||||
sio_logger = logging.getLogger("sio-logger")
 | 
			
		||||
sio_logger.setLevel(logging.ERROR)
 | 
			
		||||
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
 | 
			
		||||
 | 
			
		||||
model: ModelManager = None
 | 
			
		||||
| 
						 | 
				
			
			@ -254,6 +263,7 @@ def process():
 | 
			
		|||
        p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
 | 
			
		||||
        p2p_guidance_scale=form["p2pGuidanceScale"],
 | 
			
		||||
        controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
 | 
			
		||||
        controlnet_method=form["controlnet_method"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if config.sd_seed == -1:
 | 
			
		||||
| 
						 | 
				
			
			@ -263,7 +273,6 @@ def process():
 | 
			
		|||
 | 
			
		||||
    logger.info(f"Origin image shape: {original_shape}")
 | 
			
		||||
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
 | 
			
		||||
    logger.info(f"Resized image shape: {image.shape}")
 | 
			
		||||
 | 
			
		||||
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -436,17 +445,6 @@ def switch_model():
 | 
			
		|||
    return f"ok, switch to {new_name}", 200
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.route("/controlnet_method", methods=["POST"])
 | 
			
		||||
def switch_controlnet_method():
 | 
			
		||||
    new_method = request.form.get("method")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        model.switch_controlnet_method(new_method)
 | 
			
		||||
    except NotImplementedError:
 | 
			
		||||
        return f"Failed switch to {new_method} not implemented", 500
 | 
			
		||||
    return f"Switch to {new_method}", 200
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.route("/")
 | 
			
		||||
def index():
 | 
			
		||||
    return send_file(os.path.join(BUILD_DIR, "index.html"))
 | 
			
		||||
| 
						 | 
				
			
			@ -603,4 +601,10 @@ def main(args):
 | 
			
		|||
        )
 | 
			
		||||
        ui.run()
 | 
			
		||||
    else:
 | 
			
		||||
        socketio.run(app, host=args.host, port=args.port, debug=args.debug)
 | 
			
		||||
        socketio.run(
 | 
			
		||||
            app,
 | 
			
		||||
            host=args.host,
 | 
			
		||||
            port=args.port,
 | 
			
		||||
            debug=args.debug,
 | 
			
		||||
            allow_unsafe_werkzeug=True,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -125,7 +125,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
 | 
			
		|||
        prompt="a fox sitting on a bench",
 | 
			
		||||
        sd_steps=sd_steps,
 | 
			
		||||
        controlnet_conditioning_scale=1.0,
 | 
			
		||||
        sd_strength=1.0
 | 
			
		||||
        sd_strength=1.0,
 | 
			
		||||
    )
 | 
			
		||||
    cfg.sd_sampler = sampler
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -138,3 +138,42 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
 | 
			
		|||
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
 | 
			
		||||
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
 | 
			
		||||
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
 | 
			
		||||
def test_controlnet_switch(sd_device, sampler):
 | 
			
		||||
    if sd_device == "cuda" and not torch.cuda.is_available():
 | 
			
		||||
        return
 | 
			
		||||
    if device == "mps" and not torch.backends.mps.is_available():
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    sd_steps = 1 if sd_device == "cpu" else 30
 | 
			
		||||
    model = ModelManager(
 | 
			
		||||
        name="sd1.5",
 | 
			
		||||
        sd_controlnet=True,
 | 
			
		||||
        device=torch.device(sd_device),
 | 
			
		||||
        hf_access_token="",
 | 
			
		||||
        sd_run_local=False,
 | 
			
		||||
        disable_nsfw=True,
 | 
			
		||||
        sd_cpu_textencoder=False,
 | 
			
		||||
        cpu_offload=True,
 | 
			
		||||
        sd_controlnet_method="control_v11p_sd15_canny",
 | 
			
		||||
    )
 | 
			
		||||
    cfg = get_config(
 | 
			
		||||
        HDStrategy.ORIGINAL,
 | 
			
		||||
        prompt="a fox sitting on a bench",
 | 
			
		||||
        sd_steps=sd_steps,
 | 
			
		||||
        controlnet_method="control_v11p_sd15_inpaint",
 | 
			
		||||
    )
 | 
			
		||||
    cfg.sd_sampler = sampler
 | 
			
		||||
 | 
			
		||||
    name = f"device_{sd_device}_{sampler}"
 | 
			
		||||
 | 
			
		||||
    assert_equal(
 | 
			
		||||
        model,
 | 
			
		||||
        cfg,
 | 
			
		||||
        f"sd_controlnet_switch_to_inpaint_local_model_{name}.png",
 | 
			
		||||
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
 | 
			
		||||
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -35,6 +35,7 @@ def _save(img, name):
 | 
			
		|||
def test_remove_bg():
 | 
			
		||||
    model = RemoveBG()
 | 
			
		||||
    res = model.forward(bgr_img)
 | 
			
		||||
    res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
 | 
			
		||||
    _save(res, "test_remove_bg.png")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,6 +16,7 @@ def save_config(
 | 
			
		|||
    model,
 | 
			
		||||
    sd_local_model_path,
 | 
			
		||||
    sd_controlnet,
 | 
			
		||||
    sd_controlnet_method,
 | 
			
		||||
    device,
 | 
			
		||||
    gui,
 | 
			
		||||
    no_gui_auto_close,
 | 
			
		||||
| 
						 | 
				
			
			@ -182,6 +183,11 @@ def main(config_file: str):
 | 
			
		|||
                sd_controlnet = gr.Checkbox(
 | 
			
		||||
                    init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
 | 
			
		||||
                )
 | 
			
		||||
                sd_controlnet_method = gr.Radio(
 | 
			
		||||
                    SD_CONTROLNET_CHOICES,
 | 
			
		||||
                    lable="ControlNet method",
 | 
			
		||||
                    value=init_config.sd_controlnet_method,
 | 
			
		||||
                )
 | 
			
		||||
                no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
 | 
			
		||||
                cpu_offload = gr.Checkbox(
 | 
			
		||||
                    init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
 | 
			
		||||
| 
						 | 
				
			
			@ -207,6 +213,7 @@ def main(config_file: str):
 | 
			
		|||
                model,
 | 
			
		||||
                sd_local_model_path,
 | 
			
		||||
                sd_controlnet,
 | 
			
		||||
                sd_controlnet_method,
 | 
			
		||||
                device,
 | 
			
		||||
                gui,
 | 
			
		||||
                no_gui_auto_close,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,7 @@ torch>=1.9.0
 | 
			
		|||
opencv-python
 | 
			
		||||
flask==2.2.3
 | 
			
		||||
flask-socketio
 | 
			
		||||
simple-websocket
 | 
			
		||||
flask_cors
 | 
			
		||||
flaskwebgui==0.3.5
 | 
			
		||||
pydantic
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue