From c1214662a30e930ff40135207596335d8619ca15 Mon Sep 17 00:00:00 2001 From: nguentrungthat Date: Fri, 21 Mar 2025 09:00:39 +0700 Subject: [PATCH] Update training --- .gitignore | 3 +- src/server/routes/detect.py | 2 +- src/server/train.py | 77 +++++++++++++++++++++++++++++-------- 3 files changed, 65 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 0240c44..82b4c67 100755 --- a/.gitignore +++ b/.gitignore @@ -30,4 +30,5 @@ yarn-error.log* src/server/venv/* __pycache__/ -my_database.db \ No newline at end of file +my_database.db +latest_detect \ No newline at end of file diff --git a/src/server/routes/detect.py b/src/server/routes/detect.py index 57ef367..d75041d 100755 --- a/src/server/routes/detect.py +++ b/src/server/routes/detect.py @@ -111,7 +111,7 @@ def detect_image(): file.save(image_path) # Run with model - results = model(image_path, conf=0.6) + results = model(image_path, conf=0.5) points = [] # Result predict diff --git a/src/server/train.py b/src/server/train.py index 37fcb86..d74c397 100755 --- a/src/server/train.py +++ b/src/server/train.py @@ -13,6 +13,8 @@ import sqlite3 DATA_SPLIT_FOLDER = "model_datasets" IMAGE_EXTENSION = ".png" LOG_FILE = "training_logs.log" +LATEST_DETECT_IMAGES = "latest_detect/images" +LATEST_DETECT_LABELS = "latest_detect/labels" def get_label_names(): conn = sqlite3.connect("my_database.db") @@ -84,6 +86,17 @@ split_idx = int(len(image_files) * train_ratio) train_files = image_files[:split_idx] val_files = image_files[split_idx:] +latest_img_folder = [f for f in os.listdir(LATEST_DETECT_IMAGES) if f.endswith(IMAGE_EXTENSION)] +train_files += latest_img_folder +# Copy latest detect images to train folder +for img_file in latest_img_folder: + img_path = os.path.join(LATEST_DETECT_IMAGES, img_file) + label_file = os.path.splitext(img_file)[0] + ".txt" + label_path = os.path.join(LATEST_DETECT_LABELS, label_file) + + if os.path.exists(img_path) and os.path.exists(label_path) and not os.path.exists(os.path.join(image_folder, img_file)): + shutil.copy(img_path, os.path.join(image_folder, img_file)) + shutil.copy(label_path, os.path.join(label_folder, label_file)) def log_message(message: str): @@ -146,17 +159,43 @@ def call_reload_model_api(base_url="http://localhost:5000"): log_message(f" Error calling reload model API: {e} \n") return {"error": str(e)} +def copy_files(src_folder, dest_folder, num_files=20): + if not os.path.exists(dest_folder): + os.makedirs(dest_folder) + + files = sorted( + glob.glob(os.path.join(src_folder, "*")), + key=os.path.getmtime, # Sort by modification time (latest first) + reverse=True + )[:num_files] # Get top N files + + for file in files: + try: + shutil.copy(file, os.path.join(dest_folder, os.path.basename(file))) + # log_message(f"Copied: {file} -> {dest_folder}") + except Exception as e: + log_message(f"Error copying {file}: {e}") + def clear_images_source(): - paths = [image_folder, label_folder] - # paths = [image_folder, label_folder, train_img_folder, train_lbl_folder, val_img_folder, val_lbl_folder] + # Copy top 20 images and labels + copy_files(image_folder, LATEST_DETECT_IMAGES, 20) + copy_files(label_folder, LATEST_DETECT_LABELS, 20) + + # Delete all files in source folders + for path in [image_folder, label_folder]: + if not os.path.exists(path): + log_message(f"Path does not exist: {path}") + continue - for path in paths: for file in glob.glob(os.path.join(path, "*")): - if os.path.isfile(file): - os.remove(file) + if os.path.isfile(file): + try: + os.remove(file) + # log_message(f"Deleted: {file}") + except Exception as e: + log_message(f"Error deleting {file}: {e}") - - log_message(f"Delete source image success \n") + log_message("Delete source image success \n") log_message("END " + ("=" * 20) + "\n") @@ -185,7 +224,15 @@ def move_files(file_list, dest_img_folder, dest_lbl_folder): return copied_images, copied_labels -def train_yolo_model(pretrained_model: str, dataset_folder: str,project_name: str, name: str, epochs: int = 50, batch_size: int = 16, img_size: int = 640, lr: float = 0.001): +def train_yolo_model( + pretrained_model: str, + dataset_folder: str, + project_name: str, + name: str, + epochs: int = 50, + batch_size: int = 16, + img_size: int = 640, + lr: float = 0.001): dataset_yaml = os.path.join(dataset_folder, "data.yaml") if not os.path.exists(dataset_yaml): @@ -207,20 +254,17 @@ def train_yolo_model(pretrained_model: str, dataset_folder: str,project_name: st optimizer="AdamW", lr0=lr, weight_decay=0.0005, - patience=10, + patience=0, verbose=True, project=project_name, - name=name + name=name ) with open(LOG_FILE, "a") as log: log.write(f"\n[{datetime.datetime.now()}] Train completed\n") call_reload_model_api() - clear_images_source() - - - + clear_images_source() # 🚀 Copy files train_copied_imgs, train_copied_lbls = move_files(train_files, train_img_folder, train_lbl_folder) @@ -248,6 +292,9 @@ else: train_yolo_model(pretrained_model = get_latest_model( trained_model_folder=TRAINED_MODEL_FOLDER, default_model=PRETRAINED_MODEL), - dataset_folder = dataset_folder, epochs = 50, name=today_str, project_name=model_folder_name + dataset_folder = dataset_folder, + epochs = 50, + name=today_str, + project_name=model_folder_name, )