product-image-studio-option1/server.py

444 lines
16 KiB
Python

"""
Option 1 — Tách nền + Ghép Frame/Watermark cho ảnh sản phẩm
Stack: Flask + rembg (U^2-Net ONNX, local) + Pillow + OpenCV
Kiến trúc 2 giai đoạn:
• Tab 1 (Tách nền): xóa nền → object 1:1 trong suốt → LƯU vào thư viện (trả name).
• Tab 2 (Ghép Frame): chọn object + frame + watermark (danh sách) → SAVE render ra file.
"""
import io
import os
import re
import json
import time
import uuid
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, request, jsonify, send_from_directory, send_file
from rembg import remove, new_session
from PIL import Image, ImageOps, ImageEnhance, ImageFilter, ImageChops, ImageDraw
# All runtime folders live next to this source file.
BASE = os.path.dirname(os.path.abspath(__file__))
ASSETS = os.path.join(BASE, "assets")  # contains assets/frame/*.png and assets/watermark/*.png
OBJECTS = os.path.join(BASE, "objects")  # background-removed objects (the Tab 1 library)
OUTPUT = os.path.join(BASE, "output")
CACHE = os.path.join(BASE, "cache")  # cutouts kept after background removal, for reuse

# Make sure every working directory exists before any request is served.
for d in [os.path.join(ASSETS, sub) for sub in ("frame", "watermark")] + [OBJECTS, OUTPUT, CACHE]:
    os.makedirs(d, exist_ok=True)

app = Flask(__name__, static_folder="static", static_url_path="")
# ============================ MODEL / SESSION ============================
# Background-removal sessions that have already been created, keyed by model name.
SESSIONS = dict()
# Serialises access to SESSIONS while a model is being created.
SESSION_LOCK = threading.Lock()
# Model names accepted from clients.
ALLOWED_MODELS = {
    "u2net",
    "isnet-general-use",
    "birefnet-general-lite",
}
# Option overrides for the three quality levels of the compare mode.
COMPARE_PRESETS = dict(
    low=dict(model="u2net", anti_blowout=False, recover=0),
    medium=dict(model="isnet-general-use", anti_blowout=True, recover=2),
    high=dict(model="birefnet-general-lite", anti_blowout=True, recover=3),
)
def get_session(model: str):
    """Return the shared rembg session for ``model``, creating it on first use.

    Names outside ALLOWED_MODELS are replaced by "u2net". Creation happens
    under SESSION_LOCK, and the created session is stored in SESSIONS.
    """
    name = model if model in ALLOWED_MODELS else "u2net"
    with SESSION_LOCK:
        if name in SESSIONS:
            return SESSIONS[name]

        print(f"[model] đang khởi tạo '{name}'", flush=True)
        started = time.time()
        # Force CPUExecutionProvider: on Apple Silicon, CoreML compiling u2net hangs.
        try:
            created = new_session(name, providers=["CPUExecutionProvider"])
        except TypeError:
            # Presumably a rembg version whose new_session() takes no
            # `providers` argument; select the provider via the environment.
            os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "[CPUExecutionProvider]"
            created = new_session(name)
        SESSIONS[name] = created
        print(f"[model] '{name}' sẵn sàng sau {time.time() - started:.1f}s", flush=True)
        return created
# ============================ REMOVE BACKGROUND ============================
def remove_object(src, opt):
    """Remove the background of ``src`` and return an RGBA cutout.

    When ``opt["anti_blowout"]`` is false, rembg runs directly on the RGBA
    image. Otherwise the mask is predicted from a contrast-boosted, darkened
    RGB copy (the anti-blowout step, meant to keep corners and edge detail),
    optionally passed through a max-then-min filter of size
    ``2 * recover + 1``, and used as the alpha channel of the original pixels.
    """
    original_rgba = src.convert("RGBA")
    session = get_session(opt["model"])

    if opt["anti_blowout"]:
        boosted = ImageEnhance.Contrast(src.convert("RGB")).enhance(1.5)
        darkened = ImageEnhance.Brightness(boosted).enhance(0.85)
        mask = remove(darkened, session=session, only_mask=True, post_process_mask=True)

        radius = int(opt["recover"])
        if radius > 0:
            kernel = 2 * radius + 1
            grown = mask.filter(ImageFilter.MaxFilter(kernel))
            mask = grown.filter(ImageFilter.MinFilter(kernel))

        # The enhanced copy is only used for the mask; colours come from the original.
        result = original_rgba.copy()
        result.putalpha(mask)
        return result

    return remove(original_rgba, session=session, post_process_mask=True)
def get_cut(src, img_hash, opt):
    """Return ``(cutout, cached)`` for an image, reusing the on-disk cache.

    The cache key combines the image hash with the model, the anti-blowout
    flag and the recover value. ``cached`` is True when the cutout was loaded
    from CACHE and False when it was produced by remove_object().
    """
    # get_session() falls back to "u2net" for unknown model names. Mirror that
    # here so an unknown name shares the u2net cache entry and no
    # client-supplied text ends up in the cache file name.
    model = opt["model"] if opt["model"] in ALLOWED_MODELS else "u2net"
    key = f"{img_hash}-{model}-{int(opt['anti_blowout'])}-{opt['recover']}"
    path = os.path.join(CACHE, key + ".png")
    if os.path.exists(path):
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(path) as cached_file:
            return cached_file.convert("RGBA"), True

    cut = remove_object(src, opt)

    # Write to a unique temporary name and rename into place, so a concurrent
    # request never opens a partially written cache file.
    tmp = os.path.join(CACHE, f"{key}.{uuid.uuid4().hex}.tmp.png")
    try:
        cut.save(tmp)
        os.replace(tmp, path)
    finally:
        # Only still present when saving or renaming failed.
        if os.path.exists(tmp):
            os.remove(tmp)
    return cut, False
def fit_square(cut, obj_scale, bg):
    """Crop tightly around the object and centre it on a square 1:1 canvas.

    The canvas side is the object's longer edge divided by ``obj_scale``
    (floored at 0.05), and is at least 1 pixel. ``bg == "white"`` gives an
    opaque white canvas; any other value gives a fully transparent one.
    """
    tight_box = cut.getbbox()
    cropped = cut.crop(tight_box) if tight_box else cut

    obj_w, obj_h = cropped.size
    longest = max(obj_w, obj_h)
    scale = max(0.05, obj_scale)
    side = max(1, int(round(longest / scale)))

    if bg == "white":
        background = (255, 255, 255, 255)
    else:
        background = (0, 0, 0, 0)

    square = Image.new("RGBA", (side, side), background)
    offset = ((side - obj_w) // 2, (side - obj_h) // 2)
    square.alpha_composite(cropped, offset)
    return square
def object_polygon(cut, max_pts=48):
    """Return an outline polygon (coordinates normalised to 0..1) from the alpha channel.

    With OpenCV and NumPy available, the largest external contour of the
    thresholded alpha channel is simplified until it has at most ``max_pts``
    points. Without them, the bounding box of ``cut`` is returned. The full
    unit rectangle is the fallback when nothing usable is found.
    """
    width, height = cut.size
    full_rect = [[0, 0], [1, 0], [1, 1], [0, 1]]

    try:
        import cv2
        import numpy as np
    except Exception:  # noqa: BLE001
        # No OpenCV/NumPy: fall back to the bounding box of the image content.
        box = cut.getbbox()
        if not box:
            return full_rect
        left, top, right, bottom = box
        return [
            [left / width, top / height],
            [right / width, top / height],
            [right / width, bottom / height],
            [left / width, bottom / height],
        ]

    alpha = np.array(cut.getchannel("A"))
    _, binary = cv2.threshold(alpha, 10, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return full_rect

    largest = max(contours, key=cv2.contourArea)
    tolerance = 0.008 * cv2.arcLength(largest, True)
    simplified = cv2.approxPolyDP(largest, tolerance, True)
    # Loosen the tolerance step by step until the point budget is met.
    while len(simplified) > max_pts:
        tolerance *= 1.3
        simplified = cv2.approxPolyDP(largest, tolerance, True)

    if len(simplified) < 3:
        return full_rect
    return [[float(pt[0][0]) / width, float(pt[0][1]) / height] for pt in simplified]
def apply_poly(cut, poly):
    """Keep the alpha of ``cut`` inside the normalised polygon and clear it outside.

    ``poly`` holds points in 0..1 coordinates; they are scaled to the image
    size. The input image is not modified; a copy is returned.
    """
    width, height = cut.size
    pixel_points = [(point[0] * width, point[1] * height) for point in poly]

    region = Image.new("L", (width, height), 0)
    painter = ImageDraw.Draw(region)
    painter.polygon(pixel_points, fill=255)

    # Multiplying by the 0/255 region mask zeroes the alpha outside the polygon.
    clipped_alpha = ImageChops.multiply(cut.getchannel("A"), region)
    result = cut.copy()
    result.putalpha(clipped_alpha)
    return result
def position(pos, canvas, item, margin):
    """Return the top-left ``(x, y)`` at which to place ``item`` on ``canvas``.

    ``canvas`` and ``item`` are ``(width, height)`` pairs. ``pos`` is one of
    "northwest", "north", "northeast", "center", "southwest", "south" or
    "southeast"; any other value is treated as "southeast". ``margin`` is the
    distance kept from the edges for every position except "center" (and the
    centred axis of "north"/"south").
    """
    canvas_w, canvas_h = canvas
    item_w, item_h = item

    left = top = margin
    right = canvas_w - item_w - margin
    bottom = canvas_h - item_h - margin
    mid_x = (canvas_w - item_w) // 2
    mid_y = (canvas_h - item_h) // 2

    anchors = {
        "northwest": (left, top),
        "north": (mid_x, top),
        "northeast": (right, top),
        "center": (mid_x, mid_y),
        "southwest": (left, bottom),
        "south": (mid_x, bottom),
        "southeast": (right, bottom),
    }
    try:
        return anchors[pos]
    except KeyError:
        return anchors["southeast"]
# ============================ TAB 1: TÁCH NỀN ============================
def make_object(src, img_hash, opt):
    """Remove the background and save a square, transparent object image.

    Steps: get the (possibly cached) cutout, compute its outline polygon,
    apply the polygon crop from ``opt["poly"]`` when one is given, fit the
    result tightly onto a transparent 1:1 canvas (no frame or watermark) and
    save it into OUTPUT under a random name.

    Returns ``(output_file_name, cached, polygon)``.
    """
    cut, cached = get_cut(src, img_hash, opt)
    # The outline is taken from the cutout before any polygon crop is applied.
    polygon = object_polygon(cut)

    selection = opt.get("poly")
    if selection:
        cut = apply_poly(cut, selection)

    # Stored tight around the object; scaling is adjusted in the frame-compose step.
    square = fit_square(cut, 1.0, "transparent")
    out_name = uuid.uuid4().hex[:12] + ".png"
    square.save(os.path.join(OUTPUT, out_name))
    return out_name, cached, polygon
def parse_poly(form):
    """Parse the optional "poly" form field into a list of clamped points.

    The field is expected to be a JSON list of ``[x, y]`` pairs. Each
    coordinate is converted to float and clamped to the range 0..1.

    Returns the list of ``[x, y]`` pairs, or None when the field is missing
    or empty, cannot be parsed, or holds fewer than three points.
    """
    raw = form.get("poly")
    if not raw:
        return None

    def clamp(value):
        return max(0.0, min(1.0, float(value)))

    points = []
    try:
        for x, y in json.loads(raw):
            points.append([clamp(x), clamp(y)])
    except (ValueError, TypeError):
        return None

    if len(points) < 3:
        return None
    return points
# Upper bound for the "recover" option. remove_object() turns it into a
# (2 * recover + 1)-pixel filter kernel, so an unbounded client value could
# request an arbitrarily large kernel.
_MAX_RECOVER = 20


def build_opt(form):
    """Build the background-removal options dict from request form fields.

    Keys of the returned dict:
      model         -- one of ALLOWED_MODELS; anything else becomes "u2net"
                       (the same fallback get_session() applies).
      anti_blowout  -- True only when the field equals "true" (default "true").
      recover       -- int clamped to 0.._MAX_RECOVER; default 2 when the field
                       is missing or not an integer, 0 when it is empty.
      poly          -- result of parse_poly(form): a point list or None.
    """
    model = form.get("model", "u2net")
    if model not in ALLOWED_MODELS:
        model = "u2net"

    try:
        recover = int(form.get("recover", 2) or 0)
    except (TypeError, ValueError):
        # A malformed value should not fail the whole request; use the default.
        recover = 2
    recover = max(0, min(_MAX_RECOVER, recover))

    return {
        "model": model,
        "anti_blowout": form.get("anti_blowout", "true") == "true",
        "recover": recover,
        "poly": parse_poly(form),
    }
def save_original(data, filename):
    """Write the uploaded bytes into OUTPUT and return the generated file name.

    The stored name is ``orig-<random hex>`` plus the lower-cased extension of
    ``filename``; extensions other than .jpg/.jpeg/.png/.webp become ".png".
    """
    _, suffix = os.path.splitext(filename)
    suffix = suffix.lower()
    if suffix not in (".jpg", ".jpeg", ".png", ".webp"):
        suffix = ".png"

    name = "orig-" + uuid.uuid4().hex[:12] + suffix
    target = os.path.join(OUTPUT, name)
    with open(target, "wb") as handle:
        handle.write(data)
    return name
def run_compare(src, img_hash, filename, original, base_opt):
    """Render one image at the "low", "medium" and "high" presets in parallel.

    Each level runs make_object() with ``base_opt`` overridden by the matching
    COMPARE_PRESETS entry. Returns one result dict per level, in that order;
    a level that raised is reported with ``ok=False`` and the error text.
    """
    levels = ("low", "medium", "high")

    def render(level):
        preset_opt = dict(base_opt, **COMPARE_PRESETS[level])
        started = time.time()
        output_name, cached, polygon = make_object(src, img_hash, preset_opt)
        timing = "cache" if cached else f"{time.time()-started:.1f}s"
        print(f"[compare] {filename} · {level}{timing}", flush=True)
        return {
            "name": filename,
            "output": output_name,
            "original": original,
            "ok": True,
            "level": level,
            "cached": cached,
            "polygon": polygon,
        }

    results = []
    with ThreadPoolExecutor(max_workers=3) as pool:
        # Submit everything first, then collect in the fixed level order.
        pending = [(level, pool.submit(render, level)) for level in levels]
        for level, future in pending:
            try:
                results.append(future.result())
            except Exception as exc:  # noqa: BLE001
                print(f"[compare] {filename} · {level}{exc}", flush=True)
                results.append({"name": filename, "ok": False, "level": level, "error": str(exc)})
    return results
@app.post("/api/process")
def process():
    """Handle a batch upload: remove the background of every uploaded image.

    For each file in the "images" field the original bytes are saved, the
    image is decoded (EXIF orientation applied) and rendered either once with
    the requested options or, when ``compare == "true"``, at all three compare
    levels. A failure on one file is recorded in its result entry and does not
    stop the batch. Responds with 400 when no file was uploaded.
    """
    form = request.form
    opt = build_opt(form)
    compare = form.get("compare", "false") == "true"
    uploads = request.files.getlist("images")
    if not uploads:
        return jsonify(error="chưa chọn ảnh"), 400

    total = len(uploads)
    mode = "so sánh 3 mức" if compare else opt["model"]
    print(f"[batch] nhận {total} ảnh · {mode}", flush=True)

    results = []
    for index, upload in enumerate(uploads, 1):
        started = time.time()
        try:
            data = upload.read()
            img_hash = hashlib.sha1(data).hexdigest()[:16]
            original = save_original(data, upload.filename)
            decoded = Image.open(io.BytesIO(data))
            src = ImageOps.exif_transpose(decoded).convert("RGBA")
            if compare:
                results.extend(run_compare(src, img_hash, upload.filename, original, opt))
            else:
                out, cached, polygon = make_object(src, img_hash, opt)
                results.append({
                    "name": upload.filename,
                    "output": out,
                    "original": original,
                    "ok": True,
                    "level": None,
                    "cached": cached,
                    "polygon": polygon,
                })
            print(f"[{index}/{total}] {upload.filename}{time.time() - started:.1f}s", flush=True)
        except Exception as exc:  # noqa: BLE001
            print(f"[{index}/{total}] {upload.filename}{exc}", flush=True)
            results.append({"name": upload.filename, "ok": False, "error": str(exc)})

    print("[batch] hoàn tất.", flush=True)
    return jsonify(results=results)
@app.post("/api/recrop")
def recrop():
    """Re-render one object with a polygon selection.

    Reads the previously saved original (form field "original") from OUTPUT,
    so the cutout cache can be reused via the same image hash. When "level"
    names a compare preset, that preset overrides the submitted options.
    Responds with 404 when the original file cannot be found.
    """
    form = request.form
    original = form.get("original", "")
    path = os.path.join(OUTPUT, os.path.basename(original))
    if not (original and os.path.exists(path)):
        return jsonify(error="không tìm thấy ảnh gốc"), 404

    opt = build_opt(form)
    level = form.get("level", "")
    preset = COMPARE_PRESETS.get(level)
    if preset is not None:
        opt.update(preset)

    with open(path, "rb") as handle:
        data = handle.read()
    img_hash = hashlib.sha1(data).hexdigest()[:16]
    decoded = Image.open(io.BytesIO(data))
    src = ImageOps.exif_transpose(decoded).convert("RGBA")

    out, cached, polygon = make_object(src, img_hash, opt)
    return jsonify(
        ok=True,
        output=out,
        original=original,
        level=level or None,
        cached=cached,
        polygon=polygon,
    )
# ============================ THƯ VIỆN OBJECT ============================
def safe_name(name):
    """Turn a user-supplied name into a file-name-safe identifier.

    Surrounding whitespace is stripped, every run of characters other than
    word characters and "-" becomes a single "_", and the result is cut to 60
    characters. An empty result is replaced by 8 random hex characters.
    """
    stripped = (name or "").strip()
    cleaned = re.sub(r"[^\w\-]+", "_", stripped)
    shortened = cleaned[:60]
    if shortened:
        return shortened
    return uuid.uuid4().hex[:8]
@app.post("/api/save-object")
def save_object():
    """Copy a processed image from OUTPUT into the object library.

    Form fields: "output" (file name inside OUTPUT) and "name" (library name,
    sanitised with safe_name()). The image is re-saved as RGBA PNG under
    OBJECTS. Responds with 404 when the processed image does not exist.
    """
    form = request.form
    out = os.path.basename(form.get("output", ""))
    src_path = os.path.join(OUTPUT, out)
    if not (out and os.path.exists(src_path)):
        return jsonify(error="không tìm thấy ảnh đã xử lý"), 404

    name = safe_name(form.get("name"))
    image = Image.open(src_path).convert("RGBA")
    image.save(os.path.join(OBJECTS, name + ".png"))
    print(f"[object] đã lưu '{name}'", flush=True)
    return jsonify(ok=True, name=name, url=f"/objects/{name}.png")
@app.get("/api/objects")
def list_objects():
    """List the .png files in the object library, sorted by file name."""
    suffix = ".png"
    items = []
    for filename in sorted(os.listdir(OBJECTS)):
        if not filename.endswith(suffix):
            continue
        items.append({"name": filename[:-len(suffix)], "url": f"/objects/{filename}"})
    return jsonify(items=items)
@app.delete("/api/objects/<name>")
def del_object(name):
    """Delete a library object by name; succeeds even when it does not exist."""
    # basename() keeps the target inside OBJECTS.
    target = os.path.join(OBJECTS, f"{os.path.basename(name)}.png")
    if not os.path.exists(target):
        return jsonify(ok=True)
    os.remove(target)
    return jsonify(ok=True)
@app.get("/objects/<path:name>")
def serve_object(name):
    """Serve a file from the object library directory."""
    response = send_from_directory(OBJECTS, name)
    return response
# ============================ TAB 2: FRAME / WATERMARK ============================
def bg_color(bg):
    """Convert a hex colour string into an opaque ``(r, g, b, 255)`` tuple.

    Leading "#" characters are ignored and 3-digit shorthand is expanded.
    Returns None for an empty value, for "transparent", and for anything
    that cannot be read as hex colour digits.
    """
    if bg and bg != "transparent":
        digits = bg.lstrip("#")
        if len(digits) == 3:
            # Shorthand: "abc" stands for "aabbcc".
            digits = "".join(ch + ch for ch in digits)
        try:
            channels = [int(digits[start:start + 2], 16) for start in (0, 2, 4)]
        except (ValueError, IndexError):
            return None
        return (channels[0], channels[1], channels[2], 255)
    return None
def asset_file(kind, aid):
    """Return the path of the .png file for asset ``aid`` inside ``ASSETS/<kind>``.

    Only the base name of ``aid`` is used for the file name.
    """
    filename = os.path.basename(aid) + ".png"
    return os.path.join(ASSETS, kind, filename)
@app.post("/api/asset")
def add_asset():
    """Store an uploaded frame or watermark image as an RGBA PNG asset.

    Form fields: "kind" ("frame" or "watermark") and the uploaded "file".
    Responds with 400 for an invalid kind or a missing file; otherwise
    returns the generated asset id and its URL.
    """
    kind = request.form.get("kind")
    if kind not in ("frame", "watermark"):
        return jsonify(error="kind không hợp lệ"), 400

    upload = request.files.get("file")
    if not upload:
        return jsonify(error="thiếu file"), 400

    aid = uuid.uuid4().hex[:12]
    image = Image.open(upload.stream).convert("RGBA")
    image.save(asset_file(kind, aid))
    return jsonify(ok=True, id=aid, url=f"/api/asset/{kind}/{aid}")
@app.get("/api/asset/<kind>")
def list_assets(kind):
    """List the stored assets of one kind ("frame" or "watermark").

    Responds with 400 for any other kind, matching add_asset(). Without that
    check a kind such as ".." would list .png files outside the asset folders.
    """
    if kind not in ("frame", "watermark"):
        return jsonify(error="kind không hợp lệ"), 400

    d = os.path.join(ASSETS, kind)
    items = []
    if os.path.isdir(d):
        for fn in sorted(os.listdir(d)):
            if fn.endswith(".png"):
                aid = fn[:-4]  # strip the ".png" suffix
                items.append({"id": aid, "url": f"/api/asset/{kind}/{aid}"})
    return jsonify(items=items)
@app.get("/api/asset/<kind>/<aid>")
def get_asset(kind, aid):
    """Serve one stored asset as image/png.

    Responds with 400 for a kind other than "frame"/"watermark" (matching
    add_asset(); without the check a kind such as ".." would reach .png files
    outside the asset folders) and 404 when the asset does not exist.
    """
    if kind not in ("frame", "watermark"):
        return jsonify(error="kind không hợp lệ"), 400

    p = asset_file(kind, aid)
    if not os.path.exists(p):
        return jsonify(error="chưa có"), 404
    return send_file(p, mimetype="image/png")
@app.delete("/api/asset/<kind>/<aid>")
def del_asset(kind, aid):
    """Delete one stored asset; succeeds even when it does not exist.

    Responds with 400 for a kind other than "frame"/"watermark", matching
    add_asset(). Without that check a kind such as ".." would allow deleting
    .png files outside the asset folders.
    """
    if kind not in ("frame", "watermark"):
        return jsonify(error="kind không hợp lệ"), 400

    p = asset_file(kind, aid)
    if os.path.exists(p):
        os.remove(p)
    return jsonify(ok=True)
@app.post("/api/compose")
def compose():
    """Compose object + background + border + frame + watermark and save the result.

    Form fields:
      object      -- library object name (required)
      obj_scale   -- object size as a percentage of the canvas (default 100)
      bg          -- hex colour, or "transparent" (default)
      border      -- border width in pixels added around the canvas (default 0)
      frame       -- frame asset id (optional)
      watermark   -- watermark asset id (optional)
      wm_opacity  -- watermark opacity in percent (default 60)
      wm_pos      -- watermark position name (default "southeast")
      wm_scale    -- watermark width as a percentage of canvas width (default 25)

    Responds with 404 when the object is missing and 400 when a numeric field
    is malformed or not finite. A solid background is saved as JPEG, a
    transparent one as PNG; the response carries the output name and URL.
    """
    import math  # local import: only this handler needs it

    form = request.form
    name = os.path.basename(form.get("object", ""))
    op = os.path.join(OBJECTS, name + ".png")
    if not name or not os.path.exists(op):
        return jsonify(error="không tìm thấy object"), 404

    # Parse every numeric field up front, so bad input yields a JSON 400
    # instead of an unhandled exception halfway through rendering.
    try:
        obj_scale = float(form.get("obj_scale", 100)) / 100
        border = int(form.get("border", 0) or 0)
        wm_opacity = float(form.get("wm_opacity", 60)) / 100
        wm_scale = float(form.get("wm_scale", 25)) / 100
    except (TypeError, ValueError):
        return jsonify(error="tham số không hợp lệ"), 400
    # NaN/inf parse as floats but fail later in the int() conversions.
    if not all(math.isfinite(v) for v in (obj_scale, wm_opacity, wm_scale)):
        return jsonify(error="tham số không hợp lệ"), 400

    fill = bg_color(form.get("bg", "transparent"))  # None = transparent
    frame_id = form.get("frame", "")
    wm_id = form.get("watermark", "")
    wm_pos = form.get("wm_pos", "southeast")

    obj = Image.open(op).convert("RGBA")
    obj = fit_square(obj, obj_scale, "transparent")  # re-apply the object scale within the canvas

    if fill:
        canvas = Image.new("RGBA", obj.size, fill)
        canvas.alpha_composite(obj)
    else:
        canvas = obj.copy()

    if border > 0:
        canvas = ImageOps.expand(canvas, border=border, fill=fill or (0, 0, 0, 0))

    if frame_id and os.path.exists(asset_file("frame", frame_id)):
        # The frame is stretched to cover the whole canvas (border included).
        frame = Image.open(asset_file("frame", frame_id)).convert("RGBA").resize(canvas.size)
        canvas.alpha_composite(frame)

    if wm_id and os.path.exists(asset_file("watermark", wm_id)):
        wm = Image.open(asset_file("watermark", wm_id)).convert("RGBA")
        # Scale by width, keeping the watermark's aspect ratio.
        tw = max(1, int(canvas.width * wm_scale))
        th = max(1, int(wm.height * tw / wm.width))
        wm = wm.resize((tw, th))
        wm.putalpha(wm.split()[3].point(lambda a: int(a * wm_opacity)))
        margin = int(canvas.width * 0.03)
        x, y = position(wm_pos, canvas.size, wm.size, margin)
        canvas.alpha_composite(wm, (x, y))

    uid = uuid.uuid4().hex[:12]
    if fill:  # solid background -> JPG
        out_name = f"{uid}.jpg"
        canvas.convert("RGB").save(os.path.join(OUTPUT, out_name), quality=92)
    else:  # transparent -> PNG
        out_name = f"{uid}.png"
        canvas.save(os.path.join(OUTPUT, out_name))

    print(f"[compose] {name} + frame={frame_id or '-'} + wm={wm_id or '-'}{out_name}", flush=True)
    return jsonify(ok=True, output=out_name, url=f"/output/{out_name}")
@app.get("/")
def index():
    """Serve the single-page UI (static/index.html)."""
    page = send_from_directory("static", "index.html")
    return page
@app.get("/output/<path:name>")
def output(name):
    """Serve a file from the OUTPUT directory."""
    response = send_from_directory(OUTPUT, name)
    return response
if __name__ == "__main__":
    print("→ Option 1 (Python/rembg) đang khởi động…", flush=True)
    # Create the default model session before the server starts accepting requests.
    get_session("u2net")
    print("→ Sẵn sàng tại http://localhost:8001", flush=True)
    # NOTE(review): host 0.0.0.0 binds every interface, not only the
    # localhost URL printed above.
    app.run(
        host="0.0.0.0",
        port=8001,
        threaded=True,
        debug=False,
    )