404 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			404 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
"""
 | 
						|
AnyText: Multilingual Visual Text Generation And Editing
 | 
						|
Paper: https://arxiv.org/abs/2311.03054
 | 
						|
Code: https://github.com/tyxsspa/AnyText
 | 
						|
Copyright (c) Alibaba, Inc. and its affiliates.
 | 
						|
"""
 | 
						|
import os
 | 
						|
from pathlib import Path
 | 
						|
 | 
						|
from iopaint.model.utils import set_seed
 | 
						|
from safetensors.torch import load_file
 | 
						|
 | 
						|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 | 
						|
import torch
 | 
						|
import re
 | 
						|
import numpy as np
 | 
						|
import cv2
 | 
						|
import einops
 | 
						|
from PIL import ImageFont
 | 
						|
from iopaint.model.anytext.cldm.model import create_model, load_state_dict
 | 
						|
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
 | 
						|
from iopaint.model.anytext.utils import (
 | 
						|
    check_channels,
 | 
						|
    draw_glyph,
 | 
						|
    draw_glyph2,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
BBOX_MAX_NUM = 8
 | 
						|
PLACE_HOLDER = "*"
 | 
						|
max_chars = 20
 | 
						|
 | 
						|
ANYTEXT_CFG = os.path.join(
 | 
						|
    os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
def check_limits(tensor):
 | 
						|
    float16_min = torch.finfo(torch.float16).min
 | 
						|
    float16_max = torch.finfo(torch.float16).max
 | 
						|
 | 
						|
    # 检查张量中是否有值小于float16的最小值或大于float16的最大值
 | 
						|
    is_below_min = (tensor < float16_min).any()
 | 
						|
    is_above_max = (tensor > float16_max).any()
 | 
						|
 | 
						|
    return is_below_min or is_above_max
 | 
						|
 | 
						|
 | 
						|
class AnyTextPipeline:
 | 
						|
    def __init__(self, ckpt_path, font_path, device, use_fp16=True):
 | 
						|
        self.cfg_path = ANYTEXT_CFG
 | 
						|
        self.font_path = font_path
 | 
						|
        self.use_fp16 = use_fp16
 | 
						|
        self.device = device
 | 
						|
 | 
						|
        self.font = ImageFont.truetype(font_path, size=60)
 | 
						|
        self.model = create_model(
 | 
						|
            self.cfg_path,
 | 
						|
            device=self.device,
 | 
						|
            use_fp16=self.use_fp16,
 | 
						|
        )
 | 
						|
        if self.use_fp16:
 | 
						|
            self.model = self.model.half()
 | 
						|
        if Path(ckpt_path).suffix == ".safetensors":
 | 
						|
            state_dict = load_file(ckpt_path, device="cpu")
 | 
						|
        else:
 | 
						|
            state_dict = load_state_dict(ckpt_path, location="cpu")
 | 
						|
        self.model.load_state_dict(state_dict, strict=False)
 | 
						|
        self.model = self.model.eval().to(self.device)
 | 
						|
        self.ddim_sampler = DDIMSampler(self.model, device=self.device)
 | 
						|
 | 
						|
    def __call__(
 | 
						|
        self,
 | 
						|
        prompt: str,
 | 
						|
        negative_prompt: str,
 | 
						|
        image: np.ndarray,
 | 
						|
        masked_image: np.ndarray,
 | 
						|
        num_inference_steps: int,
 | 
						|
        strength: float,
 | 
						|
        guidance_scale: float,
 | 
						|
        height: int,
 | 
						|
        width: int,
 | 
						|
        seed: int,
 | 
						|
        sort_priority: str = "y",
 | 
						|
        callback=None,
 | 
						|
    ):
 | 
						|
        """
 | 
						|
 | 
						|
        Args:
 | 
						|
            prompt:
 | 
						|
            negative_prompt:
 | 
						|
            image:
 | 
						|
            masked_image:
 | 
						|
            num_inference_steps:
 | 
						|
            strength:
 | 
						|
            guidance_scale:
 | 
						|
            height:
 | 
						|
            width:
 | 
						|
            seed:
 | 
						|
            sort_priority: x: left-right, y: top-down
 | 
						|
 | 
						|
        Returns:
 | 
						|
            result: list of images in numpy.ndarray format
 | 
						|
            rst_code: 0: normal -1: error 1:warning
 | 
						|
            rst_info: string of error or warning
 | 
						|
 | 
						|
        """
 | 
						|
        set_seed(seed)
 | 
						|
        str_warning = ""
 | 
						|
 | 
						|
        mode = "text-editing"
 | 
						|
        revise_pos = False
 | 
						|
        img_count = 1
 | 
						|
        ddim_steps = num_inference_steps
 | 
						|
        w = width
 | 
						|
        h = height
 | 
						|
        strength = strength
 | 
						|
        cfg_scale = guidance_scale
 | 
						|
        eta = 0.0
 | 
						|
 | 
						|
        prompt, texts = self.modify_prompt(prompt)
 | 
						|
        if prompt is None and texts is None:
 | 
						|
            return (
 | 
						|
                None,
 | 
						|
                -1,
 | 
						|
                "You have input Chinese prompt but the translator is not loaded!",
 | 
						|
                "",
 | 
						|
            )
 | 
						|
        n_lines = len(texts)
 | 
						|
        if mode in ["text-generation", "gen"]:
 | 
						|
            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
 | 
						|
        elif mode in ["text-editing", "edit"]:
 | 
						|
            if masked_image is None or image is None:
 | 
						|
                return (
 | 
						|
                    None,
 | 
						|
                    -1,
 | 
						|
                    "Reference image and position image are needed for text editing!",
 | 
						|
                    "",
 | 
						|
                )
 | 
						|
            if isinstance(image, str):
 | 
						|
                image = cv2.imread(image)[..., ::-1]
 | 
						|
                assert image is not None, f"Can't read ori_image image from{image}!"
 | 
						|
            elif isinstance(image, torch.Tensor):
 | 
						|
                image = image.cpu().numpy()
 | 
						|
            else:
 | 
						|
                assert isinstance(
 | 
						|
                    image, np.ndarray
 | 
						|
                ), f"Unknown format of ori_image: {type(image)}"
 | 
						|
            edit_image = image.clip(1, 255)  # for mask reason
 | 
						|
            edit_image = check_channels(edit_image)
 | 
						|
            # edit_image = resize_image(
 | 
						|
            #     edit_image, max_length=768
 | 
						|
            # )  # make w h multiple of 64, resize if w or h > max_length
 | 
						|
            h, w = edit_image.shape[:2]  # change h, w by input ref_img
 | 
						|
        # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
 | 
						|
        if masked_image is None:
 | 
						|
            pos_imgs = np.zeros((w, h, 1))
 | 
						|
        if isinstance(masked_image, str):
 | 
						|
            masked_image = cv2.imread(masked_image)[..., ::-1]
 | 
						|
            assert (
 | 
						|
                masked_image is not None
 | 
						|
            ), f"Can't read draw_pos image from{masked_image}!"
 | 
						|
            pos_imgs = 255 - masked_image
 | 
						|
        elif isinstance(masked_image, torch.Tensor):
 | 
						|
            pos_imgs = masked_image.cpu().numpy()
 | 
						|
        else:
 | 
						|
            assert isinstance(
 | 
						|
                masked_image, np.ndarray
 | 
						|
            ), f"Unknown format of draw_pos: {type(masked_image)}"
 | 
						|
            pos_imgs = 255 - masked_image
 | 
						|
        pos_imgs = pos_imgs[..., 0:1]
 | 
						|
        pos_imgs = cv2.convertScaleAbs(pos_imgs)
 | 
						|
        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
 | 
						|
        # seprate pos_imgs
 | 
						|
        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
 | 
						|
        if len(pos_imgs) == 0:
 | 
						|
            pos_imgs = [np.zeros((h, w, 1))]
 | 
						|
        if len(pos_imgs) < n_lines:
 | 
						|
            if n_lines == 1 and texts[0] == " ":
 | 
						|
                pass  # text-to-image without text
 | 
						|
            else:
 | 
						|
                raise RuntimeError(
 | 
						|
                    f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
 | 
						|
                )
 | 
						|
        elif len(pos_imgs) > n_lines:
 | 
						|
            str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
 | 
						|
        # get pre_pos, poly_list, hint that needed for anytext
 | 
						|
        pre_pos = []
 | 
						|
        poly_list = []
 | 
						|
        for input_pos in pos_imgs:
 | 
						|
            if input_pos.mean() != 0:
 | 
						|
                input_pos = (
 | 
						|
                    input_pos[..., np.newaxis]
 | 
						|
                    if len(input_pos.shape) == 2
 | 
						|
                    else input_pos
 | 
						|
                )
 | 
						|
                poly, pos_img = self.find_polygon(input_pos)
 | 
						|
                pre_pos += [pos_img / 255.0]
 | 
						|
                poly_list += [poly]
 | 
						|
            else:
 | 
						|
                pre_pos += [np.zeros((h, w, 1))]
 | 
						|
                poly_list += [None]
 | 
						|
        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
 | 
						|
        # prepare info dict
 | 
						|
        info = {}
 | 
						|
        info["glyphs"] = []
 | 
						|
        info["gly_line"] = []
 | 
						|
        info["positions"] = []
 | 
						|
        info["n_lines"] = [len(texts)] * img_count
 | 
						|
        gly_pos_imgs = []
 | 
						|
        for i in range(len(texts)):
 | 
						|
            text = texts[i]
 | 
						|
            if len(text) > max_chars:
 | 
						|
                str_warning = (
 | 
						|
                    f'"{text}" length > max_chars: {max_chars}, will be cut off...'
 | 
						|
                )
 | 
						|
                text = text[:max_chars]
 | 
						|
            gly_scale = 2
 | 
						|
            if pre_pos[i].mean() != 0:
 | 
						|
                gly_line = draw_glyph(self.font, text)
 | 
						|
                glyphs = draw_glyph2(
 | 
						|
                    self.font,
 | 
						|
                    text,
 | 
						|
                    poly_list[i],
 | 
						|
                    scale=gly_scale,
 | 
						|
                    width=w,
 | 
						|
                    height=h,
 | 
						|
                    add_space=False,
 | 
						|
                )
 | 
						|
                gly_pos_img = cv2.drawContours(
 | 
						|
                    glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
 | 
						|
                )
 | 
						|
                if revise_pos:
 | 
						|
                    resize_gly = cv2.resize(
 | 
						|
                        glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
 | 
						|
                    )
 | 
						|
                    new_pos = cv2.morphologyEx(
 | 
						|
                        (resize_gly * 255).astype(np.uint8),
 | 
						|
                        cv2.MORPH_CLOSE,
 | 
						|
                        kernel=np.ones(
 | 
						|
                            (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
 | 
						|
                            dtype=np.uint8,
 | 
						|
                        ),
 | 
						|
                        iterations=1,
 | 
						|
                    )
 | 
						|
                    new_pos = (
 | 
						|
                        new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
 | 
						|
                    )
 | 
						|
                    contours, _ = cv2.findContours(
 | 
						|
                        new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
 | 
						|
                    )
 | 
						|
                    if len(contours) != 1:
 | 
						|
                        str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
 | 
						|
                    else:
 | 
						|
                        rect = cv2.minAreaRect(contours[0])
 | 
						|
                        poly = np.int0(cv2.boxPoints(rect))
 | 
						|
                        pre_pos[i] = (
 | 
						|
                            cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
 | 
						|
                        )
 | 
						|
                        gly_pos_img = cv2.drawContours(
 | 
						|
                            glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
 | 
						|
                        )
 | 
						|
                gly_pos_imgs += [gly_pos_img]  # for show
 | 
						|
            else:
 | 
						|
                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
 | 
						|
                gly_line = np.zeros((80, 512, 1))
 | 
						|
                gly_pos_imgs += [
 | 
						|
                    np.zeros((h * gly_scale, w * gly_scale, 1))
 | 
						|
                ]  # for show
 | 
						|
            pos = pre_pos[i]
 | 
						|
            info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
 | 
						|
            info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
 | 
						|
            info["positions"] += [self.arr2tensor(pos, img_count)]
 | 
						|
        # get masked_x
 | 
						|
        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
 | 
						|
        masked_img = np.transpose(masked_img, (2, 0, 1))
 | 
						|
        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
 | 
						|
        if self.use_fp16:
 | 
						|
            masked_img = masked_img.half()
 | 
						|
        encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
 | 
						|
        masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
 | 
						|
        if self.use_fp16:
 | 
						|
            masked_x = masked_x.half()
 | 
						|
        info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
 | 
						|
 | 
						|
        hint = self.arr2tensor(np_hint, img_count)
 | 
						|
        cond = self.model.get_learned_conditioning(
 | 
						|
            dict(
 | 
						|
                c_concat=[hint],
 | 
						|
                c_crossattn=[[prompt] * img_count],
 | 
						|
                text_info=info,
 | 
						|
            )
 | 
						|
        )
 | 
						|
        un_cond = self.model.get_learned_conditioning(
 | 
						|
            dict(
 | 
						|
                c_concat=[hint],
 | 
						|
                c_crossattn=[[negative_prompt] * img_count],
 | 
						|
                text_info=info,
 | 
						|
            )
 | 
						|
        )
 | 
						|
        shape = (4, h // 8, w // 8)
 | 
						|
        self.model.control_scales = [strength] * 13
 | 
						|
        samples, intermediates = self.ddim_sampler.sample(
 | 
						|
            ddim_steps,
 | 
						|
            img_count,
 | 
						|
            shape,
 | 
						|
            cond,
 | 
						|
            verbose=False,
 | 
						|
            eta=eta,
 | 
						|
            unconditional_guidance_scale=cfg_scale,
 | 
						|
            unconditional_conditioning=un_cond,
 | 
						|
            callback=callback
 | 
						|
        )
 | 
						|
        if self.use_fp16:
 | 
						|
            samples = samples.half()
 | 
						|
        x_samples = self.model.decode_first_stage(samples)
 | 
						|
        x_samples = (
 | 
						|
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
 | 
						|
            .cpu()
 | 
						|
            .numpy()
 | 
						|
            .clip(0, 255)
 | 
						|
            .astype(np.uint8)
 | 
						|
        )
 | 
						|
        results = [x_samples[i] for i in range(img_count)]
 | 
						|
        # if (
 | 
						|
        #     mode == "edit" and False
 | 
						|
        # ):  # replace backgound in text editing but not ideal yet
 | 
						|
        #     results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
 | 
						|
        #     results = [r.clip(0, 255).astype(np.uint8) for r in results]
 | 
						|
        # if len(gly_pos_imgs) > 0 and show_debug:
 | 
						|
        #     glyph_bs = np.stack(gly_pos_imgs, axis=2)
 | 
						|
        #     glyph_img = np.sum(glyph_bs, axis=2) * 255
 | 
						|
        #     glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
 | 
						|
        #     results += [np.repeat(glyph_img, 3, axis=2)]
 | 
						|
        rst_code = 1 if str_warning else 0
 | 
						|
        return results, rst_code, str_warning
 | 
						|
 | 
						|
    def modify_prompt(self, prompt):
 | 
						|
        prompt = prompt.replace("“", '"')
 | 
						|
        prompt = prompt.replace("”", '"')
 | 
						|
        p = '"(.*?)"'
 | 
						|
        strs = re.findall(p, prompt)
 | 
						|
        if len(strs) == 0:
 | 
						|
            strs = [" "]
 | 
						|
        else:
 | 
						|
            for s in strs:
 | 
						|
                prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
 | 
						|
        # if self.is_chinese(prompt):
 | 
						|
        #     if self.trans_pipe is None:
 | 
						|
        #         return None, None
 | 
						|
        #     old_prompt = prompt
 | 
						|
        #     prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
 | 
						|
        #     print(f"Translate: {old_prompt} --> {prompt}")
 | 
						|
        return prompt, strs
 | 
						|
 | 
						|
    # def is_chinese(self, text):
 | 
						|
    #     text = checker._clean_text(text)
 | 
						|
    #     for char in text:
 | 
						|
    #         cp = ord(char)
 | 
						|
    #         if checker._is_chinese_char(cp):
 | 
						|
    #             return True
 | 
						|
    #     return False
 | 
						|
 | 
						|
    def separate_pos_imgs(self, img, sort_priority, gap=102):
 | 
						|
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
 | 
						|
        components = []
 | 
						|
        for label in range(1, num_labels):
 | 
						|
            component = np.zeros_like(img)
 | 
						|
            component[labels == label] = 255
 | 
						|
            components.append((component, centroids[label]))
 | 
						|
        if sort_priority == "y":
 | 
						|
            fir, sec = 1, 0  # top-down first
 | 
						|
        elif sort_priority == "x":
 | 
						|
            fir, sec = 0, 1  # left-right first
 | 
						|
        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
 | 
						|
        sorted_components = [c[0] for c in components]
 | 
						|
        return sorted_components
 | 
						|
 | 
						|
    def find_polygon(self, image, min_rect=False):
 | 
						|
        contours, hierarchy = cv2.findContours(
 | 
						|
            image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
 | 
						|
        )
 | 
						|
        max_contour = max(contours, key=cv2.contourArea)  # get contour with max area
 | 
						|
        if min_rect:
 | 
						|
            # get minimum enclosing rectangle
 | 
						|
            rect = cv2.minAreaRect(max_contour)
 | 
						|
            poly = np.int0(cv2.boxPoints(rect))
 | 
						|
        else:
 | 
						|
            # get approximate polygon
 | 
						|
            epsilon = 0.01 * cv2.arcLength(max_contour, True)
 | 
						|
            poly = cv2.approxPolyDP(max_contour, epsilon, True)
 | 
						|
            n, _, xy = poly.shape
 | 
						|
            poly = poly.reshape(n, xy)
 | 
						|
        cv2.drawContours(image, [poly], -1, 255, -1)
 | 
						|
        return poly, image
 | 
						|
 | 
						|
    def arr2tensor(self, arr, bs):
 | 
						|
        arr = np.transpose(arr, (2, 0, 1))
 | 
						|
        _arr = torch.from_numpy(arr.copy()).float().to(self.device)
 | 
						|
        if self.use_fp16:
 | 
						|
            _arr = _arr.half()
 | 
						|
        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
 | 
						|
        return _arr
 |