""" Option 1 — Tách nền + Ghép Frame/Watermark cho ảnh sản phẩm Stack: Flask + rembg (U^2-Net ONNX, local) + Pillow + OpenCV Kiến trúc 2 giai đoạn: • Tab 1 (Tách nền): xóa nền → object 1:1 trong suốt → LƯU vào thư viện (trả name). • Tab 2 (Ghép Frame): chọn object + frame + watermark (danh sách) → SAVE render ra file. """ import io import os import re import json import time import uuid import hashlib import threading from concurrent.futures import ThreadPoolExecutor from flask import Flask, request, jsonify, send_from_directory, send_file from rembg import remove, new_session from PIL import Image, ImageOps, ImageEnhance, ImageFilter, ImageChops, ImageDraw BASE = os.path.dirname(os.path.abspath(__file__)) ASSETS = os.path.join(BASE, "assets") # assets/frame/*.png, assets/watermark/*.png OBJECTS = os.path.join(BASE, "objects") # object đã tách nền (thư viện Tab 1) OUTPUT = os.path.join(BASE, "output") CACHE = os.path.join(BASE, "cache") # cutout sau remove-bg, tái dùng for d in (os.path.join(ASSETS, "frame"), os.path.join(ASSETS, "watermark"), OBJECTS, OUTPUT, CACHE): os.makedirs(d, exist_ok=True) app = Flask(__name__, static_folder="static", static_url_path="") # ============================ MODEL / SESSION ============================ SESSIONS = {} SESSION_LOCK = threading.Lock() ALLOWED_MODELS = {"u2net", "isnet-general-use", "birefnet-general-lite"} COMPARE_PRESETS = { "low": {"model": "u2net", "anti_blowout": False, "recover": 0}, "medium": {"model": "isnet-general-use", "anti_blowout": True, "recover": 2}, "high": {"model": "birefnet-general-lite", "anti_blowout": True, "recover": 3}, } def get_session(model: str): if model not in ALLOWED_MODELS: model = "u2net" with SESSION_LOCK: if model not in SESSIONS: print(f"[model] đang khởi tạo '{model}'…", flush=True) t = time.time() # Ép CPUExecutionProvider: CoreML trên Apple Silicon biên dịch u2net → treo. try: SESSIONS[model] = new_session(model, providers=["CPUExecutionProvider"]) except TypeError: os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "[CPUExecutionProvider]" SESSIONS[model] = new_session(model) print(f"[model] '{model}' sẵn sàng sau {time.time() - t:.1f}s", flush=True) return SESSIONS[model] # ============================ REMOVE BACKGROUND ============================ def remove_object(src, opt): """Xóa nền, tối ưu chống cháy sáng để không mất góc/chi tiết mép.""" rgba = src.convert("RGBA") session = get_session(opt["model"]) if not opt["anti_blowout"]: return remove(rgba, session=session, post_process_mask=True) work = ImageEnhance.Contrast(src.convert("RGB")).enhance(1.5) work = ImageEnhance.Brightness(work).enhance(0.85) mask = remove(work, session=session, only_mask=True, post_process_mask=True) r = int(opt["recover"]) if r > 0: k = r * 2 + 1 mask = mask.filter(ImageFilter.MaxFilter(k)).filter(ImageFilter.MinFilter(k)) out = rgba.copy() out.putalpha(mask) return out def get_cut(src, img_hash, opt): """Cutout sau xóa nền, tái dùng cache theo (ảnh + model/anti/recover).""" key = f"{img_hash}-{opt['model']}-{int(opt['anti_blowout'])}-{opt['recover']}" path = os.path.join(CACHE, key + ".png") if os.path.exists(path): return Image.open(path).convert("RGBA"), True cut = remove_object(src, opt) cut.save(path) return cut, False def fit_square(cut, obj_scale, bg): """Crop sát object rồi đặt vào canvas vuông 1:1.""" bbox = cut.getbbox() if bbox: cut = cut.crop(bbox) w, h = cut.size side = max(1, int(round(max(w, h) / max(0.05, obj_scale)))) fill = (255, 255, 255, 255) if bg == "white" else (0, 0, 0, 0) canvas = Image.new("RGBA", (side, side), fill) canvas.alpha_composite(cut, ((side - w) // 2, (side - h) // 2)) return canvas def object_polygon(cut, max_pts=48): """Đường viền (polygon chuẩn hoá 0..1) ôm sát object, suy từ alpha.""" W, H = cut.size rect = [[0, 0], [1, 0], [1, 1], [0, 1]] try: import cv2 import numpy as np except Exception: # noqa: BLE001 bb = cut.getbbox() if not bb: return rect x0, y0, x1, y1 = bb return [[x0 / W, y0 / H], [x1 / W, y0 / H], [x1 / W, y1 / H], [x0 / W, y1 / H]] alpha = np.array(cut.getchannel("A")) _, binimg = cv2.threshold(alpha, 10, 255, cv2.THRESH_BINARY) cnts, _ = cv2.findContours(binimg, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if not cnts: return rect c = max(cnts, key=cv2.contourArea) peri = cv2.arcLength(c, True) eps = 0.008 * peri approx = cv2.approxPolyDP(c, eps, True) while len(approx) > max_pts: eps *= 1.3 approx = cv2.approxPolyDP(c, eps, True) if len(approx) < 3: return rect return [[float(p[0][0]) / W, float(p[0][1]) / H] for p in approx] def apply_poly(cut, poly): """Giữ alpha bên trong polygon (chuẩn hoá), xóa ngoài.""" W, H = cut.size pts = [(p[0] * W, p[1] * H) for p in poly] mask = Image.new("L", (W, H), 0) ImageDraw.Draw(mask).polygon(pts, fill=255) out = cut.copy() out.putalpha(ImageChops.multiply(cut.getchannel("A"), mask)) return out def position(pos, canvas, item, margin): cw, ch = canvas iw, ih = item table = { "northwest": (margin, margin), "north": ((cw - iw) // 2, margin), "northeast": (cw - iw - margin, margin), "center": ((cw - iw) // 2, (ch - ih) // 2), "southwest": (margin, ch - ih - margin), "south": ((cw - iw) // 2, ch - ih - margin), "southeast": (cw - iw - margin, ch - ih - margin), } return table.get(pos, table["southeast"]) # ============================ TAB 1: TÁCH NỀN ============================ def make_object(src, img_hash, opt): """Xóa nền → (polygon crop nếu có) → object vuông 1:1 TRONG SUỐT (không frame/wm).""" cut, cached = get_cut(src, img_hash, opt) polygon = object_polygon(cut) if opt.get("poly"): cut = apply_poly(cut, opt["poly"]) canvas = fit_square(cut, 1.0, "transparent") # lưu ôm sát; tỉ lệ chỉnh ở bước ghép frame out_name = f"{uuid.uuid4().hex[:12]}.png" canvas.save(os.path.join(OUTPUT, out_name)) return out_name, cached, polygon def parse_poly(form): raw = form.get("poly") if not raw: return None try: pts = [[max(0.0, min(1.0, float(x))), max(0.0, min(1.0, float(y)))] for x, y in json.loads(raw)] except (ValueError, TypeError): return None return pts if len(pts) >= 3 else None def build_opt(form): return { "model": form.get("model", "u2net"), "anti_blowout": form.get("anti_blowout", "true") == "true", "recover": int(form.get("recover", 2) or 0), "poly": parse_poly(form), } def save_original(data, filename): ext = os.path.splitext(filename)[1].lower() if ext not in (".jpg", ".jpeg", ".png", ".webp"): ext = ".png" name = f"orig-{uuid.uuid4().hex[:12]}{ext}" with open(os.path.join(OUTPUT, name), "wb") as fh: fh.write(data) return name def run_compare(src, img_hash, filename, original, base_opt): def run(level): opt = dict(base_opt, **COMPARE_PRESETS[level]) t = time.time() out, cached, polygon = make_object(src, img_hash, opt) print(f"[compare] {filename} · {level} ✓ {'cache' if cached else f'{time.time()-t:.1f}s'}", flush=True) return {"name": filename, "output": out, "original": original, "ok": True, "level": level, "cached": cached, "polygon": polygon} levels = ("low", "medium", "high") with ThreadPoolExecutor(max_workers=3) as ex: futures = {lv: ex.submit(run, lv) for lv in levels} out = [] for lv in levels: try: out.append(futures[lv].result()) except Exception as e: # noqa: BLE001 print(f"[compare] {filename} · {lv} ✗ {e}", flush=True) out.append({"name": filename, "ok": False, "level": lv, "error": str(e)}) return out @app.post("/api/process") def process(): opt = build_opt(request.form) compare = request.form.get("compare", "false") == "true" files = request.files.getlist("images") if not files: return jsonify(error="chưa chọn ảnh"), 400 total = len(files) print(f"[batch] nhận {total} ảnh · {'so sánh 3 mức' if compare else opt['model']}", flush=True) results = [] for i, f in enumerate(files, 1): t = time.time() try: data = f.read() img_hash = hashlib.sha1(data).hexdigest()[:16] original = save_original(data, f.filename) src = ImageOps.exif_transpose(Image.open(io.BytesIO(data))).convert("RGBA") if compare: results.extend(run_compare(src, img_hash, f.filename, original, opt)) else: out, cached, polygon = make_object(src, img_hash, opt) results.append({"name": f.filename, "output": out, "original": original, "ok": True, "level": None, "cached": cached, "polygon": polygon}) print(f"[{i}/{total}] {f.filename} ✓ {time.time() - t:.1f}s", flush=True) except Exception as e: # noqa: BLE001 print(f"[{i}/{total}] {f.filename} ✗ {e}", flush=True) results.append({"name": f.filename, "ok": False, "error": str(e)}) print("[batch] hoàn tất.", flush=True) return jsonify(results=results) @app.post("/api/recrop") def recrop(): """Render lại 1 object với vùng chọn polygon, dùng lại ảnh gốc + cache cutout.""" original = request.form.get("original", "") path = os.path.join(OUTPUT, os.path.basename(original)) if not original or not os.path.exists(path): return jsonify(error="không tìm thấy ảnh gốc"), 404 opt = build_opt(request.form) level = request.form.get("level", "") if level in COMPARE_PRESETS: opt.update(COMPARE_PRESETS[level]) with open(path, "rb") as fh: data = fh.read() img_hash = hashlib.sha1(data).hexdigest()[:16] src = ImageOps.exif_transpose(Image.open(io.BytesIO(data))).convert("RGBA") out, cached, polygon = make_object(src, img_hash, opt) return jsonify(ok=True, output=out, original=original, level=level or None, cached=cached, polygon=polygon) # ============================ THƯ VIỆN OBJECT ============================ def safe_name(name): name = re.sub(r"[^\w\-]+", "_", (name or "").strip()) return name[:60] or uuid.uuid4().hex[:8] @app.post("/api/save-object") def save_object(): out = os.path.basename(request.form.get("output", "")) src_path = os.path.join(OUTPUT, out) if not out or not os.path.exists(src_path): return jsonify(error="không tìm thấy ảnh đã xử lý"), 404 name = safe_name(request.form.get("name")) Image.open(src_path).convert("RGBA").save(os.path.join(OBJECTS, name + ".png")) print(f"[object] đã lưu '{name}'", flush=True) return jsonify(ok=True, name=name, url=f"/objects/{name}.png") @app.get("/api/objects") def list_objects(): items = [{"name": fn[:-4], "url": f"/objects/{fn}"} for fn in sorted(os.listdir(OBJECTS)) if fn.endswith(".png")] return jsonify(items=items) @app.delete("/api/objects/") def del_object(name): p = os.path.join(OBJECTS, os.path.basename(name) + ".png") if os.path.exists(p): os.remove(p) return jsonify(ok=True) @app.get("/objects/") def serve_object(name): return send_from_directory(OBJECTS, name) # ============================ TAB 2: FRAME / WATERMARK ============================ def bg_color(bg): """Trả (r,g,b,255) cho mã hex, hoặc None nếu trong suốt.""" if not bg or bg == "transparent": return None h = bg.lstrip("#") if len(h) == 3: h = "".join(c * 2 for c in h) try: return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16), 255) except (ValueError, IndexError): return None def asset_file(kind, aid): return os.path.join(ASSETS, kind, f"{os.path.basename(aid)}.png") @app.post("/api/asset") def add_asset(): kind = request.form.get("kind") if kind not in ("frame", "watermark"): return jsonify(error="kind không hợp lệ"), 400 f = request.files.get("file") if not f: return jsonify(error="thiếu file"), 400 aid = uuid.uuid4().hex[:12] Image.open(f.stream).convert("RGBA").save(asset_file(kind, aid)) return jsonify(ok=True, id=aid, url=f"/api/asset/{kind}/{aid}") @app.get("/api/asset/") def list_assets(kind): d = os.path.join(ASSETS, kind) items = [] if os.path.isdir(d): for fn in sorted(os.listdir(d)): if fn.endswith(".png"): aid = fn[:-4] items.append({"id": aid, "url": f"/api/asset/{kind}/{aid}"}) return jsonify(items=items) @app.get("/api/asset//") def get_asset(kind, aid): p = asset_file(kind, aid) if not os.path.exists(p): return jsonify(error="chưa có"), 404 return send_file(p, mimetype="image/png") @app.delete("/api/asset//") def del_asset(kind, aid): p = asset_file(kind, aid) if os.path.exists(p): os.remove(p) return jsonify(ok=True) @app.post("/api/compose") def compose(): """Ghép object + frame + watermark → render & LƯU file, trả filename.""" name = os.path.basename(request.form.get("object", "")) op = os.path.join(OBJECTS, name + ".png") if not name or not os.path.exists(op): return jsonify(error="không tìm thấy object"), 404 obj = Image.open(op).convert("RGBA") obj_scale = float(request.form.get("obj_scale", 100)) / 100 obj = fit_square(obj, obj_scale, "transparent") # đặt lại tỉ lệ object trong khung fill = bg_color(request.form.get("bg", "transparent")) # None = trong suốt border = int(request.form.get("border", 0) or 0) frame_id = request.form.get("frame", "") wm_id = request.form.get("watermark", "") wm_opacity = float(request.form.get("wm_opacity", 60)) / 100 wm_pos = request.form.get("wm_pos", "southeast") wm_scale = float(request.form.get("wm_scale", 25)) / 100 if fill: canvas = Image.new("RGBA", obj.size, fill) canvas.alpha_composite(obj) else: canvas = obj.copy() if border > 0: canvas = ImageOps.expand(canvas, border=border, fill=fill or (0, 0, 0, 0)) if frame_id and os.path.exists(asset_file("frame", frame_id)): frame = Image.open(asset_file("frame", frame_id)).convert("RGBA").resize(canvas.size) canvas.alpha_composite(frame) if wm_id and os.path.exists(asset_file("watermark", wm_id)): wm = Image.open(asset_file("watermark", wm_id)).convert("RGBA") tw = max(1, int(canvas.width * wm_scale)) th = max(1, int(wm.height * tw / wm.width)) wm = wm.resize((tw, th)) wm.putalpha(wm.split()[3].point(lambda a: int(a * wm_opacity))) margin = int(canvas.width * 0.03) x, y = position(wm_pos, canvas.size, wm.size, margin) canvas.alpha_composite(wm, (x, y)) uid = uuid.uuid4().hex[:12] if fill: # nền đặc → JPG out_name = f"{uid}.jpg" canvas.convert("RGB").save(os.path.join(OUTPUT, out_name), quality=92) else: # trong suốt → PNG out_name = f"{uid}.png" canvas.save(os.path.join(OUTPUT, out_name)) print(f"[compose] {name} + frame={frame_id or '-'} + wm={wm_id or '-'} → {out_name}", flush=True) return jsonify(ok=True, output=out_name, url=f"/output/{out_name}") @app.get("/") def index(): return send_from_directory("static", "index.html") @app.get("/output/") def output(name): return send_from_directory(OUTPUT, name) if __name__ == "__main__": print("→ Option 1 (Python/rembg) đang khởi động…", flush=True) get_session("u2net") print("→ Sẵn sàng tại http://localhost:8001", flush=True) app.run(host="0.0.0.0", port=8001, debug=False, threaded=True)