105 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			105 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
# services/detect_service.py
 | 
						|
import os
from typing import List, Optional

import cv2
import numpy as np
 | 
						|
 | 
						|
 | 
						|
class DetectService:
    """Locate UI elements in a screenshot using OpenCV template matching.

    Template images (.png/.jpg/.jpeg) live under ``template_dir``. Each
    template's label is the path of its parent folder relative to
    ``template_dir``, lower-cased and ``/``-separated (e.g. ``"username"`` or
    ``"buttons/login"``). Templates placed directly in ``template_dir`` get
    the label ``"."``.
    """

    # File extensions (compared case-insensitively) treated as template images.
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(
        self,
        template_dir: str,
        target_labels: Optional[List[str]] = None,
        threshold: float = 0.75,
        overlap_thresh: float = 0.3,
    ):
        """Configure the detector.

        Args:
            template_dir: Directory containing the template images, organised
                in sub-folders whose relative paths act as labels.
            target_labels: Labels (folder paths relative to template_dir,
                e.g. ["username", "password", "buttons/login"]) to restrict
                detection to. Matching is case-insensitive, and backslashes or
                leading/trailing slashes are normalised away. None or an empty
                list means every folder is searched.
            threshold: Minimum TM_CCOEFF_NORMED score for a match to be kept.
            overlap_thresh: IoU at or above which a lower-scoring box of a
                different label is suppressed by a kept box.

        Raises:
            FileNotFoundError: If template_dir does not exist.
        """
        self.template_dir = os.path.abspath(template_dir)
        self.threshold = threshold
        self.overlap_thresh = overlap_thresh
        # Normalise exactly like the folder labels computed in detect(),
        # otherwise "Buttons\\Login" or "buttons/login/" would never match.
        self.target_labels = (
            [t.replace("\\", "/").strip("/").lower() for t in target_labels]
            if target_labels
            else None
        )

        if not os.path.exists(self.template_dir):
            raise FileNotFoundError(f"Template dir not found: {self.template_dir}")

    # ----------------------------------------------------
    def detect(self, screen):
        """Match every template against ``screen`` and de-duplicate the hits.

        Args:
            screen: Screenshot as a numpy image, expected BGR (3 channels).
                A 4-channel (BGRA) image is converted to BGR, and a 2-D
                single-channel image is matched against grayscale templates.

        Returns:
            The output of non_max_suppression: a list of
            (label, filename, top_left, bottom_right, score) tuples where
            top_left is the (x, y) match position, bottom_right is top_left
            plus the template's (width, height), and score is a float.

        Raises:
            ValueError: If screen is None.
        """
        if screen is None:
            raise ValueError("screen must be a numpy image, got None")
        if screen.ndim == 3 and screen.shape[2] == 4:
            # matchTemplate needs equal channel counts; templates are read as
            # 3-channel BGR, so drop the alpha channel from the screen.
            screen = cv2.cvtColor(screen, cv2.COLOR_BGRA2BGR)
        grayscale = screen.ndim == 2
        screen_h, screen_w = screen.shape[:2]

        regions = []
        for root, dirs, files in os.walk(self.template_dir):
            # Sorting in place makes os.walk descend in a deterministic order,
            # so equal scores are tie-broken the same way on every run.
            dirs.sort()
            label = os.path.relpath(root, self.template_dir).replace("\\", "/").lower()

            # When target_labels is set, only search the requested folders
            # (sub-folders are still walked and checked individually).
            if self.target_labels and label not in self.target_labels:
                continue

            for file in sorted(files):
                if not file.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue

                template = cv2.imread(
                    os.path.join(root, file),
                    cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR,
                )
                if template is None:  # unreadable / undecodable image
                    continue

                tmpl_h, tmpl_w = template.shape[:2]
                # matchTemplate raises when the template is larger than the
                # image in either dimension; such a template cannot match.
                if tmpl_h > screen_h or tmpl_w > screen_w:
                    continue

                res = cv2.matchTemplate(screen, template, cv2.TM_CCOEFF_NORMED)
                ys, xs = np.where(res >= self.threshold)
                for x, y in zip(xs, ys):
                    regions.append((
                        label,
                        file,
                        (int(x), int(y)),
                        (int(x + tmpl_w), int(y + tmpl_h)),
                        float(res[y, x]),
                    ))

        return self.non_max_suppression(regions)

    # ----------------------------------------------------
    def non_max_suppression(self, regions):
        """Greedily remove duplicate detections, keeping the highest scores.

        Boxes are visited in descending score order (stable for ties). Each
        kept box removes every remaining box with the SAME label, regardless
        of overlap, so at most one detection per label survives. It also
        removes every remaining box of a DIFFERENT label whose IoU with it is
        >= overlap_thresh.

        NOTE(review): the one-box-per-label policy means repeated instances of
        the same element are not reported; confirm this is intended.

        Args:
            regions: Iterable of (label, filename, top_left, bottom_right, score).

        Returns:
            List of kept (label, filename, (x1, y1), (x2, y2), score) tuples
            with int coordinates, in descending score order.
        """
        remaining = sorted(
            (
                [x1, y1, x2, y2, score, label, file]
                for label, file, (x1, y1), (x2, y2), score in regions
            ),
            key=lambda box: box[4],
            reverse=True,
        )

        kept = []
        while remaining:
            best = remaining[0]
            kept.append(best)
            remaining = [
                box for box in remaining[1:]
                if box[5] != best[5] and self.iou(box, best) < self.overlap_thresh
            ]

        return [
            (box[5], box[6], (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), box[4])
            for box in kept
        ]

    # ----------------------------------------------------
    def iou(self, boxA, boxB):
        """Return the Intersection-over-Union of two boxes.

        Each box is a sequence whose first four items are x1, y1, x2, y2, with
        width x2 - x1 and height y2 - y1. A 1e-5 epsilon in the denominator
        avoids division by zero for zero-area boxes.
        """
        inter_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
        inter_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
        inter_area = inter_w * inter_h
        area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        return inter_area / float(area_a + area_b - inter_area + 1e-5)
 |