facebook-tool/services/detect_service.py

105 lines
3.7 KiB
Python

# services/detect_service.py
import os
from typing import List, Optional

import cv2
import numpy as np
class DetectService:
    """Locate UI elements on a screenshot by OpenCV template matching.

    Template images live in (sub-)folders of ``template_dir``. The folder path
    relative to ``template_dir`` (lowercase, ``/``-separated, ``"."`` for the
    root itself) is used as the label of every match found with the images
    inside it.
    """

    def __init__(
        self,
        template_dir: str,
        target_labels: Optional[List[str]] = None,
        threshold: float = 0.75,
        overlap_thresh: float = 0.3,
    ):
        """
        Args:
            template_dir: Folder containing the template images (walked
                recursively by :meth:`detect`).
            target_labels: Folders, relative to ``template_dir``, that correspond
                to the fields to detect, e.g.
                ``["username", "password", "buttons/login"]``. Matching is
                case-insensitive and accepts either backslash or forward-slash
                separators. ``None`` or an empty list means every folder is
                scanned.
            threshold: Minimum ``TM_CCOEFF_NORMED`` score for a match to be kept.
            overlap_thresh: IoU at or above which a lower-scored box is
                suppressed by :meth:`non_max_suppression`.

        Raises:
            FileNotFoundError: if ``template_dir`` is not an existing directory.
        """
        self.template_dir = os.path.abspath(template_dir)
        self.threshold = threshold
        self.overlap_thresh = overlap_thresh
        # Normalise exactly like detect() normalises folder paths, so that
        # "Buttons\Login" and "buttons/login/" both select the same folder.
        self.target_labels = (
            [self._normalize_label(t) for t in target_labels] if target_labels else None
        )
        # isdir rather than exists: for a regular file os.walk yields nothing,
        # so detect() would silently return no matches.
        if not os.path.isdir(self.template_dir):
            raise FileNotFoundError(f"Template dir not found: {self.template_dir}")

    # ----------------------------------------------------
    @staticmethod
    def _normalize_label(label):
        """Return ``label`` lowercased, ``/``-separated, without edge slashes."""
        return label.replace("\\", "/").strip("/").lower()

    @staticmethod
    def _read_template(path):
        """Load the image at ``path`` as 3-channel BGR, or return None.

        The raw bytes are read with numpy and decoded with ``cv2.imdecode``
        instead of calling ``cv2.imread``, because ``cv2.imread`` is known to
        fail on non-ASCII paths on Windows. Unreadable, empty or undecodable
        files yield None, mirroring ``cv2.imread`` returning None.
        """
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)

    @staticmethod
    def _template_fits(screen, template):
        """True if ``template`` is no larger than ``screen`` in both dimensions.

        ``cv2.matchTemplate`` raises when the template is bigger than the image.
        """
        return template.shape[0] <= screen.shape[0] and template.shape[1] <= screen.shape[1]

    # ----------------------------------------------------
    def detect(self, screen):
        """Find every template occurrence in ``screen`` and de-duplicate them.

        Templates are re-read from disk on every call.

        Args:
            screen: Screenshot as a numpy image. Templates are decoded as
                3-channel BGR (``cv2.IMREAD_COLOR``), so ``screen`` must be a
                3-channel BGR array of a matching dtype (presumably uint8;
                confirm with the capture code).

        Returns:
            list of ``(label, filename, top_left, bottom_right, score)`` tuples,
            as filtered by :meth:`non_max_suppression`. ``top_left`` and
            ``bottom_right`` are ``(x, y)`` int tuples, and ``bottom_right`` is
            ``top_left`` plus the template's width and height.
        """
        regions = []
        for root, dirs, files in os.walk(self.template_dir):
            # Sort in place so the walk order, and therefore NMS tie-breaking
            # between equal scores, is deterministic across filesystems.
            dirs.sort()
            label = self._normalize_label(os.path.relpath(root, self.template_dir))
            # When target_labels is given, only scan the folders it lists.
            if self.target_labels and label not in self.target_labels:
                continue
            for file in sorted(files):
                if not file.lower().endswith((".png", ".jpg", ".jpeg")):
                    continue
                template = self._read_template(os.path.join(root, file))
                if template is None:
                    continue
                if not self._template_fits(screen, template):
                    continue
                res = cv2.matchTemplate(screen, template, cv2.TM_CCOEFF_NORMED)
                height, width = template.shape[:2]
                ys, xs = np.where(res >= self.threshold)
                for x, y in zip(xs, ys):
                    x, y = int(x), int(y)
                    regions.append(
                        (label, file, (x, y), (x + width, y + height), float(res[y, x]))
                    )
        return self.non_max_suppression(regions)

    # ----------------------------------------------------
    def non_max_suppression(self, regions):
        """Reduce duplicate detections, visiting boxes from highest score down.

        Each picked box removes every remaining box that has the SAME label
        (wherever it is) or overlaps it with IoU >= ``self.overlap_thresh``
        (whatever its label). The result therefore holds at most one box per
        label.

        NOTE(review): one box per label differs from classic per-class NMS.
        Confirm this is intended before detecting labels that can legitimately
        appear several times on screen.

        Args:
            regions: list of ``(label, filename, top_left, bottom_right, score)``.

        Returns:
            The kept regions in the same tuple format, ordered by descending
            score (ties keep their input order, since ``sorted`` is stable).
        """
        if not regions:
            return []
        # Flatten into [x1, y1, x2, y2, score, label, file] for iou().
        boxes = [
            [top_left[0], top_left[1], bottom_right[0], bottom_right[1], score, label, file]
            for label, file, top_left, bottom_right, score in regions
        ]
        boxes = sorted(boxes, key=lambda b: b[4], reverse=True)
        pick = []
        while boxes:
            current = boxes.pop(0)
            pick.append(current)
            boxes = [
                b for b in boxes
                if b[5] != current[5] and self.iou(b, current) < self.overlap_thresh
            ]
        return [
            (b[5], b[6], (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), b[4])
            for b in pick
        ]

    # ----------------------------------------------------
    def iou(self, boxA, boxB):
        """Return the Intersection-over-Union of two boxes.

        Each box is a sequence whose first four items are ``x1, y1, x2, y2``,
        with area computed as ``(x2 - x1) * (y2 - y1)``. A small epsilon in the
        denominator avoids division by zero for degenerate (zero-area) boxes.
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        return interArea / float(boxAArea + boxBArea - interArea + 1e-5)