128 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			128 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
import json
 | 
						|
from pathlib import Path
 | 
						|
from typing import Dict, Optional
 | 
						|
 | 
						|
import cv2
 | 
						|
import psutil
 | 
						|
from PIL import Image
 | 
						|
from loguru import logger
 | 
						|
from rich.console import Console
 | 
						|
from rich.progress import (
 | 
						|
    Progress,
 | 
						|
    SpinnerColumn,
 | 
						|
    TimeElapsedColumn,
 | 
						|
    MofNCompleteColumn,
 | 
						|
    TextColumn,
 | 
						|
    BarColumn,
 | 
						|
    TaskProgressColumn,
 | 
						|
)
 | 
						|
 | 
						|
from iopaint.helper import pil_to_bytes
 | 
						|
from iopaint.model.utils import torch_gc
 | 
						|
from iopaint.model_manager import ModelManager
 | 
						|
from iopaint.schema import InpaintRequest
 | 
						|
 | 
						|
 | 
						|
def glob_images(path: Path) -> Dict[str, Path]:
    """Map file stem -> path for the png/jpg/jpeg images found at ``path``.

    ``path`` may be:

    * a single file — returned as ``{stem: path}`` regardless of its suffix,
      since the caller named it explicitly;
    * a directory — its direct children that are regular files with a
      ``.png``/``.jpg``/``.jpeg`` suffix (case-insensitive) are collected.
      Children are visited in sorted order so a batch runs deterministically;
      if two files share a stem (e.g. ``a.jpg`` and ``a.png``), the one that
      sorts last wins.

    Returns an empty dict when ``path`` is neither a file nor a directory
    (e.g. it does not exist), so callers can simply test ``if not result``.
    """
    if path.is_file():
        return {path.stem: path}
    if not path.is_dir():
        # Previously fell through and returned None, which crashed callers
        # that take len() of the result.
        return {}
    return {
        child.stem: child
        for child in sorted(path.iterdir())
        # is_file() excludes sub-directories named like "x.png".
        if child.is_file() and child.suffix.lower() in (".png", ".jpg", ".jpeg")
    }
 | 
						|
 | 
						|
 | 
						|
def batch_inpaint(
    model: str,
    device,
    image: Path,
    mask: Path,
    output: Path,
    config: Optional[Path] = None,
    concat: bool = False,
):
    """Inpaint every image at ``image`` with its mask and write the results.

    Images and masks are collected with ``glob_images`` and paired by file
    stem. When ``mask`` is a single file it is used for every image; when it
    is a directory, images without a same-stem mask are logged and skipped.
    Images or masks that OpenCV cannot decode are logged and skipped too.
    Each result is written as PNG to ``output / f"{stem}.png"``.

    Args:
        model: model name passed to ``ModelManager(name=...)``.
        device: passed through unchanged to ``ModelManager(device=...)``.
        image: an image file, or a directory of png/jpg/jpeg images.
        mask: a mask file, or a directory of masks named like the images.
        output: directory for the results; created (with parents) if missing.
        config: optional JSON file whose top-level keys are passed as keyword
            arguments to ``InpaintRequest``; defaults are used when omitted.
        concat: if True, save ``[image | mask | result]`` side by side
            instead of the result alone.

    Exits the process with status -1 (``sys.exit``) when ``output`` is an
    existing file or when no images/masks are found.
    """
    import sys

    # Results are always written *into* output, so it can never be a file
    # (previously only checked for directory input; a file input then crashed
    # in mkdir with FileExistsError).
    if output.is_file():
        logger.error("invalid --output: output should be a directory, not a file")
        sys.exit(-1)
    output.mkdir(parents=True, exist_ok=True)

    image_paths = glob_images(image)
    mask_paths = glob_images(mask)
    # `not` also covers a None result for a path that does not exist.
    if not image_paths:
        logger.error(f"invalid --image: no png/jpg/jpeg images found at {image}")
        sys.exit(-1)
    if not mask_paths:
        logger.error(f"invalid --mask: no png/jpg/jpeg masks found at {mask}")
        sys.exit(-1)

    if config is None:
        inpaint_request = InpaintRequest()
        logger.info(f"Using default config: {inpaint_request}")
    else:
        with open(config, "r", encoding="utf-8") as f:
            inpaint_request = InpaintRequest(**json.load(f))

    model_manager = ModelManager(name=model, device=device)
    # Fallback mask; only reached when `mask` is a single file, because
    # directory masks without a stem match are skipped below.
    first_mask = next(iter(mask_paths.values()))

    console = Console()
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
        transient=False,
    ) as progress:
        task = progress.add_task("Batch processing...", total=len(image_paths))
        for stem, image_p in image_paths.items():
            if stem not in mask_paths and mask.is_dir():
                progress.log(f"mask for {image_p} not found")
            else:
                mask_p = mask_paths.get(stem, first_mask)
                img_bytes = _inpaint_one(
                    image_p, mask_p, model_manager, inpaint_request, concat, progress.log
                )
                if img_bytes is not None:
                    (output / f"{stem}.png").write_bytes(img_bytes)
                torch_gc()
            # Advance exactly once per image, whether processed or skipped.
            progress.update(task, advance=1)


def _inpaint_one(
    image_p: Path,
    mask_p: Path,
    model_manager,
    inpaint_request,
    concat: bool,
    log,
) -> Optional[bytes]:
    """Inpaint one image/mask pair and return the PNG-encoded result.

    Returns None (after reporting through ``log``) when OpenCV cannot decode
    the image or the mask, so the caller can skip the pair instead of the
    whole batch crashing inside ``cvtColor``.
    """
    # cv2.imread signals a decode failure by returning None, not by raising.
    img = cv2.imread(str(image_p))
    if img is None:
        log(f"failed to read image {image_p}, skipped")
        return None
    # imread's default IMREAD_COLOR yields 3-channel BGR; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Keep the source image's PIL metadata (.info) for the saved PNG. The
    # context manager closes the file handle Image.open otherwise leaves open.
    with Image.open(image_p) as pil_img:
        infos = dict(pil_img.info)

    mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
    if mask_img is None:
        log(f"failed to read mask {mask_p}, skipped")
        return None
    if mask_img.shape[:2] != img.shape[:2]:
        log(f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}")
        # cv2.resize takes (width, height); nearest-neighbour avoids
        # introducing intermediate gray values at mask edges.
        mask_img = cv2.resize(
            mask_img,
            (img.shape[1], img.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    # Binarize the mask to {0, 255} with a threshold of 127.
    mask_img[mask_img >= 127] = 255
    mask_img[mask_img < 127] = 0

    # The model output is treated as BGR here (TODO confirm against
    # ModelManager) and converted to RGB for PIL.
    inpaint_result = model_manager(img, mask_img, inpaint_request)
    inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
    if concat:
        mask_rgb = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
        inpaint_result = cv2.hconcat([img, mask_rgb, inpaint_result])

    return pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
 |