444 lines
16 KiB
Python
444 lines
16 KiB
Python
"""
Option 1 — Background removal + Frame/Watermark compositing for product images.

Stack: Flask + rembg (U^2-Net ONNX, local) + Pillow + OpenCV

Two-stage architecture:
  • Tab 1 (Background removal): remove the background → transparent 1:1 object
    → SAVE it to the library (returns its name).
  • Tab 2 (Frame compositing): choose object + frame + watermark (from lists)
    → SAVE renders the result to a file.
"""
|
|
import io
|
|
import os
|
|
import re
|
|
import json
|
|
import time
|
|
import uuid
|
|
import hashlib
|
|
import threading
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from flask import Flask, request, jsonify, send_from_directory, send_file
|
|
from rembg import remove, new_session
|
|
from PIL import Image, ImageOps, ImageEnhance, ImageFilter, ImageChops, ImageDraw
|
|
|
|
# Working directories, all resolved relative to the folder containing this file.
BASE = os.path.dirname(os.path.abspath(__file__))
ASSETS = os.path.join(BASE, "assets")    # assets/frame/*.png, assets/watermark/*.png
OBJECTS = os.path.join(BASE, "objects")  # background-removed objects (Tab 1 library)
OUTPUT = os.path.join(BASE, "output")
CACHE = os.path.join(BASE, "cache")      # cutouts produced by background removal, reused

# Make sure every directory exists before any request handler touches it.
for d in [os.path.join(ASSETS, kind) for kind in ("frame", "watermark")] + [OBJECTS, OUTPUT, CACHE]:
    os.makedirs(d, exist_ok=True)

# The Flask app serves ./static at the URL root.
app = Flask(__name__, static_folder="static", static_url_path="")
|
|
|
|
# ============================ MODEL / SESSION ============================
|
|
# rembg sessions created so far, keyed by model name; SESSION_LOCK guards the dict.
SESSIONS = {}
SESSION_LOCK = threading.Lock()

# Model names get_session() accepts; anything else falls back to "u2net" there.
ALLOWED_MODELS = {
    "u2net",
    "isnet-general-use",
    "birefnet-general-lite",
}

# Option overrides applied per quality level when compare mode is requested.
COMPARE_PRESETS = {
    "low": dict(model="u2net", anti_blowout=False, recover=0),
    "medium": dict(model="isnet-general-use", anti_blowout=True, recover=2),
    "high": dict(model="birefnet-general-lite", anti_blowout=True, recover=3),
}
|
|
|
|
|
|
def get_session(model: str):
    """Return the shared rembg session for *model*, creating it on first use.

    A name that is not in ALLOWED_MODELS is replaced by "u2net". The lookup
    and the creation both happen while SESSION_LOCK is held, so a session is
    created at most once per model name.
    """
    if model not in ALLOWED_MODELS:
        model = "u2net"
    with SESSION_LOCK:
        if model in SESSIONS:
            return SESSIONS[model]

        print(f"[model] đang khởi tạo '{model}'…", flush=True)
        t = time.time()
        # Force CPUExecutionProvider: per the original author's note, CoreML on
        # Apple Silicon compiles u2net and hangs.
        try:
            SESSIONS[model] = new_session(model, providers=["CPUExecutionProvider"])
        except TypeError:
            # new_session() rejected the `providers` keyword: request the CPU
            # provider through the environment variable instead.
            os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "[CPUExecutionProvider]"
            SESSIONS[model] = new_session(model)
        print(f"[model] '{model}' sẵn sàng sau {time.time() - t:.1f}s", flush=True)
        return SESSIONS[model]
|
|
|
|
|
|
# ============================ REMOVE BACKGROUND ============================
|
|
def remove_object(src, opt):
    """Remove the background from *src* and return an RGBA image.

    With ``opt["anti_blowout"]`` set, the mask is computed on a
    contrast-boosted, darkened copy of the image (the original author's guard
    against blown-out highlights eating corners and edge detail) and then
    applied as alpha to the untouched pixels. ``opt["recover"]`` > 0 runs a
    max filter followed by a min filter of size ``2 * recover + 1`` over the
    mask first.
    """
    rgba = src.convert("RGBA")
    session = get_session(opt["model"])

    if opt["anti_blowout"]:
        boosted = ImageEnhance.Contrast(src.convert("RGB")).enhance(1.5)
        darkened = ImageEnhance.Brightness(boosted).enhance(0.85)
        mask = remove(darkened, session=session, only_mask=True, post_process_mask=True)

        radius = int(opt["recover"])
        if radius > 0:
            kernel = 2 * radius + 1
            grown = mask.filter(ImageFilter.MaxFilter(kernel))
            mask = grown.filter(ImageFilter.MinFilter(kernel))

        result = rgba.copy()
        result.putalpha(mask)
        return result

    # Plain path: let rembg cut the RGBA image directly.
    return remove(rgba, session=session, post_process_mask=True)
|
|
|
|
|
|
def get_cut(src, img_hash, opt):
    """Return ``(cutout, cached)`` for *src*, reusing the on-disk cache.

    The cache key combines the image hash with the model, anti-blowout flag
    and recover value, so each option combination gets its own PNG in CACHE.

    Returns:
        A tuple of the RGBA cutout and a bool that is True when the cutout
        was loaded from the cache rather than computed.
    """
    key = f"{img_hash}-{opt['model']}-{int(opt['anti_blowout'])}-{opt['recover']}"
    path = os.path.join(CACHE, key + ".png")
    if os.path.exists(path):
        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(path) as cached_img:
            return cached_img.convert("RGBA"), True

    cut = remove_object(src, opt)

    # Write to a private temp file and rename it into place: the server is
    # threaded, so another request must never see a half-written cache PNG.
    tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
    try:
        cut.save(tmp_path, format="PNG")
        os.replace(tmp_path, path)
    finally:
        # Only still present when save/replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return cut, False
|
|
|
|
|
|
def fit_square(cut, obj_scale, bg):
    """Crop tightly around the object, then centre it on a square 1:1 canvas.

    The canvas side is the object's longer edge divided by *obj_scale*
    (floored at 0.05), and never less than 1 pixel. The canvas is opaque
    white when *bg* is "white" and fully transparent otherwise.
    """
    bounds = cut.getbbox()
    trimmed = cut.crop(bounds) if bounds else cut
    width, height = trimmed.size

    scale = max(0.05, obj_scale)
    side = max(1, int(round(max(width, height) / scale)))

    if bg == "white":
        fill = (255, 255, 255, 255)
    else:
        fill = (0, 0, 0, 0)

    canvas = Image.new("RGBA", (side, side), fill)
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.alpha_composite(trimmed, offset)
    return canvas
|
|
|
|
|
|
def object_polygon(cut, max_pts=48):
    """Return an outline polygon (normalised to 0..1) derived from the alpha channel.

    The largest external contour of the thresholded alpha channel is
    simplified until it has at most *max_pts* vertices. The full-image
    rectangle is returned when no contour is found or the simplified outline
    has fewer than three points. If OpenCV/NumPy cannot be imported, the
    alpha bounding box (or the full rectangle when it is empty) is returned.
    """
    width, height = cut.size
    full_rect = [[0, 0], [1, 0], [1, 1], [0, 1]]

    try:
        import cv2
        import numpy as np
    except Exception:  # noqa: BLE001
        # No OpenCV/NumPy available: settle for the bounding box.
        box = cut.getbbox()
        if not box:
            return full_rect
        left, top, right, bottom = box
        return [
            [left / width, top / height],
            [right / width, top / height],
            [right / width, bottom / height],
            [left / width, bottom / height],
        ]

    alpha = np.array(cut.getchannel("A"))
    _, binary = cv2.threshold(alpha, 10, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return full_rect

    largest = max(contours, key=cv2.contourArea)
    tolerance = 0.008 * cv2.arcLength(largest, True)
    outline = cv2.approxPolyDP(largest, tolerance, True)
    # Loosen the tolerance step by step until the outline is small enough.
    while len(outline) > max_pts:
        tolerance *= 1.3
        outline = cv2.approxPolyDP(largest, tolerance, True)

    if len(outline) < 3:
        return full_rect
    return [[float(pt[0][0]) / width, float(pt[0][1]) / height] for pt in outline]
|
|
|
|
|
|
def apply_poly(cut, poly):
    """Keep the alpha inside the (normalised) polygon and clear it outside.

    *poly* holds points in 0..1 coordinates; they are scaled to the size of
    *cut*. A modified copy is returned and *cut* itself is left untouched.
    """
    width, height = cut.size
    vertices = [(pt[0] * width, pt[1] * height) for pt in poly]

    # 255 inside the polygon, 0 elsewhere.
    keep = Image.new("L", (width, height), 0)
    painter = ImageDraw.Draw(keep)
    painter.polygon(vertices, fill=255)

    result = cut.copy()
    masked_alpha = ImageChops.multiply(cut.getchannel("A"), keep)
    result.putalpha(masked_alpha)
    return result
|
|
|
|
|
|
def position(pos, canvas, item, margin):
    """Return the top-left ``(x, y)`` at which *item* sits on *canvas*.

    Args:
        pos: Gravity name: one of the eight compass points or "center".
        canvas: ``(width, height)`` of the canvas.
        item: ``(width, height)`` of the item being placed.
        margin: Gap in pixels kept from the canvas edge on edge-aligned axes.

    Returns:
        The ``(x, y)`` offset. Unknown gravity names fall back to "southeast".
    """
    cw, ch = canvas
    iw, ih = item

    left, mid_x, right = margin, (cw - iw) // 2, cw - iw - margin
    top, mid_y, bottom = margin, (ch - ih) // 2, ch - ih - margin

    table = {
        "northwest": (left, top),
        "north": (mid_x, top),
        "northeast": (right, top),
        # "west" and "east" were missing and silently rendered as southeast.
        "west": (left, mid_y),
        "center": (mid_x, mid_y),
        "east": (right, mid_y),
        "southwest": (left, bottom),
        "south": (mid_x, bottom),
        "southeast": (right, bottom),
    }
    return table.get(pos, table["southeast"])
|
|
|
|
|
|
# ============================ TAB 1: TÁCH NỀN ============================
|
|
def make_object(src, img_hash, opt):
    """Cut out the object, optionally mask it by polygon, and save a square PNG.

    The outline polygon is computed from the cutout before any user-supplied
    ``opt["poly"]`` is applied. The result is a transparent, tightly fitted
    1:1 image (no frame or watermark) written to OUTPUT under a random name.

    Returns:
        ``(out_name, cached, polygon)``: the saved file name, whether the
        cutout came from the cache, and the detected outline polygon.
    """
    cut, cached = get_cut(src, img_hash, opt)
    polygon = object_polygon(cut)

    selection = opt.get("poly")
    if selection:
        cut = apply_poly(cut, selection)

    # Saved tightly fitted; the object's scale is chosen later, at compose time.
    canvas = fit_square(cut, 1.0, "transparent")
    out_name = uuid.uuid4().hex[:12] + ".png"
    canvas.save(os.path.join(OUTPUT, out_name))
    return out_name, cached, polygon
|
|
|
|
|
|
def parse_poly(form):
    """Parse the optional "poly" form field into clamped polygon points.

    The field is expected to hold a JSON list of ``[x, y]`` pairs. Each
    coordinate is converted to float and clamped to 0..1.

    Returns:
        A list of ``[x, y]`` float pairs, or None when the field is missing,
        empty, malformed, or has fewer than three points.
    """
    raw = form.get("poly")
    if not raw:
        return None
    try:
        pts = [[max(0.0, min(1.0, float(x))), max(0.0, min(1.0, float(y)))]
               for x, y in json.loads(raw)]
    except (ValueError, TypeError, OverflowError, RecursionError):
        # OverflowError: a JSON integer too large for float().
        # RecursionError: pathologically nested JSON.
        # Both are bad client input, not server errors.
        return None
    return pts if len(pts) >= 3 else None
|
|
|
|
|
|
def build_opt(form):
    """Build the background-removal options dict from a request form.

    Keys produced: "model", "anti_blowout", "recover", "poly".

    * "model" falls back to "u2net" unless it is in ALLOWED_MODELS, matching
      what get_session() would load anyway. This keeps arbitrary client text
      out of the cache file name.
    * "recover" defaults to 2 when missing or unparsable, is 0 when blank,
      and is clamped to 0..20 because it sizes a filter kernel.
    """
    max_recover = 20  # safety cap: kernel size is 2 * recover + 1

    model = form.get("model", "u2net")
    if model not in ALLOWED_MODELS:
        model = "u2net"

    try:
        recover = int(form.get("recover", 2) or 0)
    except (TypeError, ValueError):
        recover = 2
    recover = max(0, min(max_recover, recover))

    return {
        "model": model,
        "anti_blowout": form.get("anti_blowout", "true") == "true",
        "recover": recover,
        "poly": parse_poly(form),
    }
|
|
|
|
|
|
def save_original(data, filename):
    """Store the uploaded bytes in OUTPUT and return the generated file name.

    The upload's extension is kept when it is .jpg, .jpeg, .png or .webp
    (compared in lowercase); any other extension becomes .png. The stored
    name is "orig-" plus a random 12-character hex id.
    """
    allowed = (".jpg", ".jpeg", ".png", ".webp")
    suffix = os.path.splitext(filename)[1].lower()
    if suffix not in allowed:
        suffix = ".png"

    name = "orig-" + uuid.uuid4().hex[:12] + suffix
    target = os.path.join(OUTPUT, name)
    with open(target, "wb") as fh:
        fh.write(data)
    return name
|
|
|
|
|
|
def run_compare(src, img_hash, filename, original, base_opt):
    """Render *src* at the three compare levels and return one result per level.

    Each level merges its COMPARE_PRESETS entry over *base_opt* and runs in
    its own worker thread. Results are returned in low/medium/high order; a
    level that raised is reported as an ``ok=False`` entry carrying the error
    text, and the other levels are still returned.
    """
    def render(level):
        level_opt = {**base_opt, **COMPARE_PRESETS[level]}
        started = time.time()
        output_name, cached, polygon = make_object(src, img_hash, level_opt)
        timing = "cache" if cached else f"{time.time() - started:.1f}s"
        print(f"[compare] {filename} · {level} ✓ {timing}", flush=True)
        return {
            "name": filename,
            "output": output_name,
            "original": original,
            "ok": True,
            "level": level,
            "cached": cached,
            "polygon": polygon,
        }

    levels = ("low", "medium", "high")
    collected = []
    with ThreadPoolExecutor(max_workers=3) as pool:
        pending = {level: pool.submit(render, level) for level in levels}
        for level in levels:
            try:
                entry = pending[level].result()
            except Exception as exc:  # noqa: BLE001
                print(f"[compare] {filename} · {level} ✗ {exc}", flush=True)
                entry = {"name": filename, "ok": False, "level": level, "error": str(exc)}
            collected.append(entry)
    return collected
|
|
|
|
|
|
@app.post("/api/process")
def process():
    """Remove the background from every uploaded image and report the results.

    Form fields: the options read by build_opt(), plus ``compare``
    ("true"/"false"). Files arrive under ``images``. Returns 400 when no
    file was sent. Otherwise each image yields one result entry, or three in
    compare mode. A failure on one image is recorded as an ``ok=False``
    entry and does not abort the batch.
    """
    opt = build_opt(request.form)
    compare = request.form.get("compare", "false") == "true"
    uploads = request.files.getlist("images")
    if not uploads:
        return jsonify(error="chưa chọn ảnh"), 400

    total = len(uploads)
    mode = "so sánh 3 mức" if compare else opt["model"]
    print(f"[batch] nhận {total} ảnh · {mode}", flush=True)

    results = []
    for index, upload in enumerate(uploads, start=1):
        started = time.time()
        try:
            data = upload.read()
            img_hash = hashlib.sha1(data).hexdigest()[:16]
            original = save_original(data, upload.filename)
            # Apply the EXIF orientation before any processing.
            decoded = Image.open(io.BytesIO(data))
            src = ImageOps.exif_transpose(decoded).convert("RGBA")
            if compare:
                results.extend(run_compare(src, img_hash, upload.filename, original, opt))
            else:
                out_name, cached, polygon = make_object(src, img_hash, opt)
                results.append({
                    "name": upload.filename,
                    "output": out_name,
                    "original": original,
                    "ok": True,
                    "level": None,
                    "cached": cached,
                    "polygon": polygon,
                })
            print(f"[{index}/{total}] {upload.filename} ✓ {time.time() - started:.1f}s", flush=True)
        except Exception as exc:  # noqa: BLE001
            print(f"[{index}/{total}] {upload.filename} ✗ {exc}", flush=True)
            results.append({"name": upload.filename, "ok": False, "error": str(exc)})

    print("[batch] hoàn tất.", flush=True)
    return jsonify(results=results)
|
|
|
|
|
|
@app.post("/api/recrop")
def recrop():
    """Re-render one object with a polygon selection.

    Reuses the original upload stored in OUTPUT (form field ``original``)
    and, through make_object(), the cached cutout. A ``level`` naming a
    compare preset overrides the matching options. Returns 404 when the
    original file is not found.
    """
    original = request.form.get("original", "")
    # basename() keeps the lookup inside OUTPUT.
    path = os.path.join(OUTPUT, os.path.basename(original))
    if not original or not os.path.exists(path):
        return jsonify(error="không tìm thấy ảnh gốc"), 404

    opt = build_opt(request.form)
    level = request.form.get("level", "")
    preset = COMPARE_PRESETS.get(level)
    if preset is not None:
        opt.update(preset)

    with open(path, "rb") as fh:
        data = fh.read()
    img_hash = hashlib.sha1(data).hexdigest()[:16]
    decoded = Image.open(io.BytesIO(data))
    src = ImageOps.exif_transpose(decoded).convert("RGBA")

    out_name, cached, polygon = make_object(src, img_hash, opt)
    return jsonify(
        ok=True,
        output=out_name,
        original=original,
        level=level or None,
        cached=cached,
        polygon=polygon,
    )
|
|
|
|
|
|
# ============================ THƯ VIỆN OBJECT ============================
|
|
def safe_name(name):
    """Turn a user-supplied name into a safe file stem.

    Surrounding whitespace is stripped, each run of characters other than
    word characters and "-" becomes a single "_", and the result is cut to
    60 characters. An empty result is replaced by a random 8-character hex id.
    """
    text = name if name else ""
    cleaned = re.sub(r"[^\w\-]+", "_", text.strip())
    stem = cleaned[:60]
    if stem:
        return stem
    return uuid.uuid4().hex[:8]
|
|
|
|
|
|
@app.post("/api/save-object")
def save_object():
    """Copy a processed image from OUTPUT into the object library.

    Form fields: ``output`` (a file name inside OUTPUT) and ``name`` (the
    library name, sanitised by safe_name()). The image is re-saved as an
    RGBA PNG in OBJECTS. Returns 404 when the processed file is not found.
    """
    processed = os.path.basename(request.form.get("output", ""))
    source_path = os.path.join(OUTPUT, processed)
    if not processed or not os.path.exists(source_path):
        return jsonify(error="không tìm thấy ảnh đã xử lý"), 404

    name = safe_name(request.form.get("name"))
    target_path = os.path.join(OBJECTS, name + ".png")
    image = Image.open(source_path).convert("RGBA")
    image.save(target_path)
    print(f"[object] đã lưu '{name}'", flush=True)
    return jsonify(ok=True, name=name, url=f"/objects/{name}.png")
|
|
|
|
|
|
@app.get("/api/objects")
def list_objects():
    """List every PNG in the object library, sorted by file name."""
    items = []
    for filename in sorted(os.listdir(OBJECTS)):
        if not filename.endswith(".png"):
            continue
        # Drop the four-character ".png" suffix to get the object name.
        items.append({"name": filename[:-4], "url": f"/objects/{filename}"})
    return jsonify(items=items)
|
|
|
|
|
|
@app.delete("/api/objects/<name>")
def del_object(name):
    """Delete one library object; a missing object is not treated as an error."""
    # basename() keeps the path inside OBJECTS.
    stem = os.path.basename(name)
    target = os.path.join(OBJECTS, stem + ".png")
    if os.path.exists(target):
        os.remove(target)
    return jsonify(ok=True)
|
|
|
|
|
|
@app.get("/objects/<path:name>")
def serve_object(name):
    """Serve a file from the object library directory."""
    response = send_from_directory(OBJECTS, name)
    return response
|
|
|
|
|
|
# ============================ TAB 2: FRAME / WATERMARK ============================
|
|
def bg_color(bg):
    """Convert a hex colour string to an opaque ``(r, g, b, 255)`` tuple.

    Accepted forms, with or without leading "#": 3 hex digits (each digit
    doubled), 6 hex digits, or 8 hex digits (the trailing alpha pair is
    ignored, as before).

    Returns:
        The RGBA tuple, or None when *bg* is empty, "transparent", or not a
        valid hex colour.
    """
    if not bg or bg == "transparent":
        return None
    digits = bg.lstrip("#")
    # Validate strictly: int(..., 16) alone accepts signs and whitespace, so
    # "#-1-1-1" used to yield negative channel values.
    if not re.fullmatch(r"[0-9a-fA-F]{3}|[0-9a-fA-F]{6}(?:[0-9a-fA-F]{2})?", digits):
        return None
    if len(digits) == 3:
        digits = "".join(c * 2 for c in digits)
    return (int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16), 255)
|
|
|
|
|
|
def asset_file(kind, aid):
    """Return the PNG path for asset *aid* inside the ASSETS/*kind* folder."""
    # basename() drops any directory part from the id.
    filename = f"{os.path.basename(aid)}.png"
    return os.path.join(ASSETS, kind, filename)
|
|
|
|
|
|
@app.post("/api/asset")
def add_asset():
    """Store an uploaded frame or watermark as an RGBA PNG asset.

    Form fields: ``kind`` ("frame" or "watermark") and the file ``file``.
    Returns 400 for an invalid kind, a missing file, or an upload that
    cannot be decoded as an image. On success the response carries the new
    asset id and its URL.
    """
    kind = request.form.get("kind")
    if kind not in ("frame", "watermark"):
        return jsonify(error="kind không hợp lệ"), 400
    f = request.files.get("file")
    if not f:
        return jsonify(error="thiếu file"), 400

    # Decode before choosing a name, so a bad upload is answered with 400
    # instead of an unhandled 500.
    try:
        image = Image.open(f.stream).convert("RGBA")
    except (OSError, ValueError):
        return jsonify(error="file ảnh không hợp lệ"), 400

    aid = uuid.uuid4().hex[:12]
    image.save(asset_file(kind, aid))
    return jsonify(ok=True, id=aid, url=f"/api/asset/{kind}/{aid}")
|
|
|
|
|
|
@app.get("/api/asset/<kind>")
def list_assets(kind):
    """List the PNG assets of one kind ("frame" or "watermark").

    Any other *kind* yields an empty list. The kind is checked here, as
    add_asset() does, so values such as "." or ".." cannot be used to list
    PNG files outside the two asset folders.
    """
    items = []
    if kind not in ("frame", "watermark"):
        return jsonify(items=items)

    d = os.path.join(ASSETS, kind)
    if os.path.isdir(d):
        for fn in sorted(os.listdir(d)):
            if fn.endswith(".png"):
                aid = fn[:-4]
                items.append({"id": aid, "url": f"/api/asset/{kind}/{aid}"})
    return jsonify(items=items)
|
|
|
|
|
|
@app.get("/api/asset/<kind>/<aid>")
def get_asset(kind, aid):
    """Serve one frame/watermark PNG, or 404 when it does not exist.

    An unknown *kind* is answered like a missing file. Without this check a
    kind of ".." would resolve to a PNG in the parent directory.
    """
    if kind not in ("frame", "watermark"):
        return jsonify(error="chưa có"), 404
    p = asset_file(kind, aid)
    if not os.path.exists(p):
        return jsonify(error="chưa có"), 404
    return send_file(p, mimetype="image/png")
|
|
|
|
|
|
@app.delete("/api/asset/<kind>/<aid>")
def del_asset(kind, aid):
    """Delete one frame/watermark PNG; a missing asset is not treated as an error.

    An unknown *kind* is ignored. Without this check a kind of ".." would
    delete a PNG from the parent directory.
    """
    if kind in ("frame", "watermark"):
        p = asset_file(kind, aid)
        if os.path.exists(p):
            os.remove(p)
    return jsonify(ok=True)
|
|
|
|
|
|
def _form_number(form, key, default):
    """Read a finite float from *form*, falling back to *default*.

    The default is used when the field is missing, blank, not a number, or
    NaN/infinite.
    """
    import math  # local import: the module header does not import math

    raw = form.get(key)
    if raw is None or raw == "":
        return float(default)
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return float(default)
    return value if math.isfinite(value) else float(default)


@app.post("/api/compose")
def compose():
    """Composite object + frame + watermark, save the render, return its name.

    Form fields:
        object: library object name (required; 404 when not found).
        obj_scale: object size as a percentage of the canvas (default 100).
        bg: hex background colour, or "transparent" (default).
        border: extra border in pixels (default 0).
        frame / watermark: asset ids; ignored when the asset file is absent.
        wm_opacity: watermark opacity in percent (default 60), clamped 0..100.
        wm_pos: watermark gravity (default "southeast").
        wm_scale: watermark width as a percentage of canvas width (default 25).

    Missing, blank, malformed or non-finite numeric fields use their
    defaults. A solid background is saved as JPEG, a transparent one as PNG.
    """
    form = request.form
    name = os.path.basename(form.get("object", ""))
    op = os.path.join(OBJECTS, name + ".png")
    if not name or not os.path.exists(op):
        return jsonify(error="không tìm thấy object"), 404
    obj = Image.open(op).convert("RGBA")

    obj_scale = _form_number(form, "obj_scale", 100) / 100
    obj = fit_square(obj, obj_scale, "transparent")  # re-scale the object inside the canvas

    fill = bg_color(form.get("bg", "transparent"))  # None = transparent
    border = int(_form_number(form, "border", 0))
    frame_id = form.get("frame", "")
    wm_id = form.get("watermark", "")
    # Clamp to 0..1 so the scaled alpha values stay within 0..255.
    wm_opacity = min(1.0, max(0.0, _form_number(form, "wm_opacity", 60) / 100))
    wm_pos = form.get("wm_pos", "southeast")
    wm_scale = _form_number(form, "wm_scale", 25) / 100

    if fill:
        canvas = Image.new("RGBA", obj.size, fill)
        canvas.alpha_composite(obj)
    else:
        canvas = obj.copy()

    if border > 0:
        canvas = ImageOps.expand(canvas, border=border, fill=fill or (0, 0, 0, 0))

    if frame_id and os.path.exists(asset_file("frame", frame_id)):
        # The frame is stretched to cover the whole canvas.
        frame = Image.open(asset_file("frame", frame_id)).convert("RGBA").resize(canvas.size)
        canvas.alpha_composite(frame)

    if wm_id and os.path.exists(asset_file("watermark", wm_id)):
        wm = Image.open(asset_file("watermark", wm_id)).convert("RGBA")
        # Scale by target width, keeping the watermark's aspect ratio.
        tw = max(1, int(canvas.width * wm_scale))
        th = max(1, int(wm.height * tw / wm.width))
        wm = wm.resize((tw, th))
        wm.putalpha(wm.split()[3].point(lambda a: int(a * wm_opacity)))
        margin = int(canvas.width * 0.03)
        x, y = position(wm_pos, canvas.size, wm.size, margin)
        canvas.alpha_composite(wm, (x, y))

    uid = uuid.uuid4().hex[:12]
    if fill:  # solid background → JPG
        out_name = f"{uid}.jpg"
        canvas.convert("RGB").save(os.path.join(OUTPUT, out_name), quality=92)
    else:  # transparent → PNG
        out_name = f"{uid}.png"
        canvas.save(os.path.join(OUTPUT, out_name))
    print(f"[compose] {name} + frame={frame_id or '-'} + wm={wm_id or '-'} → {out_name}", flush=True)
    return jsonify(ok=True, output=out_name, url=f"/output/{out_name}")
|
|
|
|
|
|
@app.get("/")
def index():
    """Serve the single-page UI (static/index.html)."""
    page = send_from_directory("static", "index.html")
    return page
|
|
|
|
|
|
@app.get("/output/<path:name>")
def output(name):
    """Serve a file from the OUTPUT directory."""
    response = send_from_directory(OUTPUT, name)
    return response
|
|
|
|
|
|
if __name__ == "__main__":
    # Load the default model before serving, so the model-initialisation
    # delay happens at startup rather than on the first u2net request.
    print("→ Option 1 (Python/rembg) đang khởi động…", flush=True)
    get_session("u2net")
    print("→ Sẵn sàng tại http://localhost:8001", flush=True)
    # Listens on all interfaces (0.0.0.0), not only localhost.
    server_options = {"host": "0.0.0.0", "port": 8001, "debug": False, "threaded": True}
    app.run(**server_options)
|