105 lines
3.7 KiB
Python
105 lines
3.7 KiB
Python
# services/detect_service.py
|
|
import os
from typing import List, Optional

import cv2
import numpy as np
|
|
|
|
|
|
class DetectService:
    """Locate UI elements on a screenshot via OpenCV template matching.

    Template images live under ``template_dir``. The folder that contains a
    template, expressed relative to ``template_dir`` (lower-cased, with
    forward slashes), is used as the label of every match that template
    produces. Templates stored directly in ``template_dir`` get the label
    ``"."``.
    """

    # Lower-case file extensions that are treated as template images.
    _TEMPLATE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(
        self,
        template_dir: str,
        target_labels: Optional[List[str]] = None,
        threshold: float = 0.75,
        overlap_thresh: float = 0.3,
    ):
        """Configure the detector.

        Args:
            template_dir: Directory that contains the template images.
            target_labels: Folders (relative to ``template_dir``) of the
                fields to detect, e.g.
                ``["username", "password", "buttons/login"]``. ``None`` or
                an empty list means "scan every folder".
            threshold: Minimum ``TM_CCOEFF_NORMED`` score for a location to
                count as a match.
            overlap_thresh: IoU value at or above which a lower-scoring box
                is discarded by :meth:`non_max_suppression`.

        Raises:
            FileNotFoundError: If ``template_dir`` does not exist.
        """
        self.template_dir = os.path.abspath(template_dir)
        self.threshold = threshold
        self.overlap_thresh = overlap_thresh
        self.target_labels = (
            [self._normalize_label(label) for label in target_labels]
            if target_labels
            else None
        )

        if not os.path.exists(self.template_dir):
            raise FileNotFoundError(f"Template dir not found: {self.template_dir}")

    @staticmethod
    def _normalize_label(label: str) -> str:
        """Return *label* in the form used for folder labels in ``detect``."""
        # detect() builds labels with forward slashes and lower case, so a
        # target written as "Buttons\\Login" or "username/" must be folded
        # the same way or it can never match.
        return label.replace("\\", "/").strip("/").lower()

    # ----------------------------------------------------
    def detect(self, screen):
        """Match every selected template against *screen*.

        Args:
            screen: Screenshot as a numpy BGR image.

        Returns:
            A list of ``(label, filename, top_left, bottom_right, score)``
            tuples after non-max suppression, where ``top_left`` and
            ``bottom_right`` are ``(x, y)`` pixel coordinates.
        """
        regions = []
        # Build the lookup set once instead of scanning a list per folder.
        wanted = set(self.target_labels) if self.target_labels else None

        for root, dirs, files in os.walk(self.template_dir):
            # Sorting in place makes the traversal (and therefore the order
            # of equal-score results) independent of the filesystem.
            dirs.sort()
            rel_path = (
                os.path.relpath(root, self.template_dir).replace("\\", "/").lower()
            )

            # When target labels are defined, only scan the folders needed.
            if wanted is not None and rel_path not in wanted:
                continue

            for file in sorted(files):
                if not file.lower().endswith(self._TEMPLATE_EXTENSIONS):
                    continue

                template = cv2.imread(os.path.join(root, file))
                if template is None:
                    # Unreadable or corrupt image: ignore it.
                    continue

                tpl_h, tpl_w = template.shape[:2]
                screen_h, screen_w = screen.shape[:2]
                # matchTemplate requires the template to fit inside the
                # searched image; an oversized template cannot match, so
                # skip it instead of letting cv2 abort the whole scan.
                if tpl_h > screen_h or tpl_w > screen_w:
                    continue

                res = cv2.matchTemplate(screen, template, cv2.TM_CCOEFF_NORMED)
                # np.where yields (rows, cols), i.e. (y, x) coordinates.
                ys, xs = np.where(res >= self.threshold)

                for x, y in zip(xs.tolist(), ys.tolist()):
                    regions.append(
                        (
                            rel_path,
                            file,
                            (x, y),
                            (x + tpl_w, y + tpl_h),
                            float(res[y, x]),
                        )
                    )

        return self.non_max_suppression(regions)

    # ----------------------------------------------------
    def non_max_suppression(self, regions):
        """Reduce duplicate detections.

        Boxes are visited from highest to lowest score. After a box is
        picked, a remaining box survives only if it has a *different* label
        and its IoU with the picked box is below ``self.overlap_thresh``.
        Consequently the result holds at most one box per label, and boxes
        of different labels never overlap by ``overlap_thresh`` or more.

        NOTE(review): suppressing every same-label box regardless of overlap
        is the existing behavior and is preserved here; confirm that "one
        detection per label" is intended.

        Args:
            regions: Iterable of
                ``(label, filename, top_left, bottom_right, score)`` tuples.

        Returns:
            The surviving regions in the same tuple format, best score
            first.
        """
        if not regions:
            return []

        # sorted() is stable, so equal scores keep their input order.
        remaining = sorted(
            (
                [x1, y1, x2, y2, score, label, file]
                for label, file, (x1, y1), (x2, y2), score in regions
            ),
            key=lambda box: box[4],
            reverse=True,
        )

        pick = []
        while remaining:
            current, rest = remaining[0], remaining[1:]
            pick.append(current)
            remaining = [
                box
                for box in rest
                if box[5] != current[5]
                and self.iou(box, current) < self.overlap_thresh
            ]

        return [
            (box[5], box[6], (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), box[4])
            for box in pick
        ]

    # ----------------------------------------------------
    def iou(self, boxA, boxB):
        """Return the Intersection-over-Union of two boxes.

        Each box is a sequence whose first four items are
        ``x1, y1, x2, y2``; any further items are ignored.
        """
        inter_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
        inter_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
        inter_area = inter_w * inter_h

        area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

        # The small epsilon avoids division by zero for degenerate boxes.
        return inter_area / float(area_a + area_b - inter_area + 1e-5)
|