Update training
This commit is contained in:
parent
db1f306145
commit
c1214662a3
|
|
@ -30,4 +30,5 @@ yarn-error.log*
|
|||
src/server/venv/*
|
||||
__pycache__/
|
||||
|
||||
my_database.db
|
||||
my_database.db
|
||||
latest_detect
|
||||
|
|
@ -111,7 +111,7 @@ def detect_image():
|
|||
file.save(image_path)
|
||||
|
||||
# Run with model
|
||||
results = model(image_path, conf=0.6)
|
||||
results = model(image_path, conf=0.5)
|
||||
points = []
|
||||
|
||||
# Result predict
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ import sqlite3
|
|||
DATA_SPLIT_FOLDER = "model_datasets"
|
||||
IMAGE_EXTENSION = ".png"
|
||||
LOG_FILE = "training_logs.log"
|
||||
LATEST_DETECT_IMAGES = "latest_detect/images"
|
||||
LATEST_DETECT_LABELS = "latest_detect/labels"
|
||||
|
||||
def get_label_names():
|
||||
conn = sqlite3.connect("my_database.db")
|
||||
|
|
@ -84,6 +86,17 @@ split_idx = int(len(image_files) * train_ratio)
|
|||
train_files = image_files[:split_idx]
|
||||
val_files = image_files[split_idx:]
|
||||
|
||||
latest_img_folder = [f for f in os.listdir(LATEST_DETECT_IMAGES) if f.endswith(IMAGE_EXTENSION)]
|
||||
train_files += latest_img_folder
|
||||
# Copy latest detect images to train folder
|
||||
for img_file in latest_img_folder:
|
||||
img_path = os.path.join(LATEST_DETECT_IMAGES, img_file)
|
||||
label_file = os.path.splitext(img_file)[0] + ".txt"
|
||||
label_path = os.path.join(LATEST_DETECT_LABELS, label_file)
|
||||
|
||||
if os.path.exists(img_path) and os.path.exists(label_path) and not os.path.exists(os.path.join(image_folder, img_file)):
|
||||
shutil.copy(img_path, os.path.join(image_folder, img_file))
|
||||
shutil.copy(label_path, os.path.join(label_folder, label_file))
|
||||
|
||||
|
||||
def log_message(message: str):
|
||||
|
|
@ -146,17 +159,43 @@ def call_reload_model_api(base_url="http://localhost:5000"):
|
|||
log_message(f" Error calling reload model API: {e} \n")
|
||||
return {"error": str(e)}
|
||||
|
||||
def copy_files(src_folder, dest_folder, num_files=20):
|
||||
if not os.path.exists(dest_folder):
|
||||
os.makedirs(dest_folder)
|
||||
|
||||
files = sorted(
|
||||
glob.glob(os.path.join(src_folder, "*")),
|
||||
key=os.path.getmtime, # Sort by modification time (latest first)
|
||||
reverse=True
|
||||
)[:num_files] # Get top N files
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
shutil.copy(file, os.path.join(dest_folder, os.path.basename(file)))
|
||||
# log_message(f"Copied: {file} -> {dest_folder}")
|
||||
except Exception as e:
|
||||
log_message(f"Error copying {file}: {e}")
|
||||
|
||||
def clear_images_source():
|
||||
paths = [image_folder, label_folder]
|
||||
# paths = [image_folder, label_folder, train_img_folder, train_lbl_folder, val_img_folder, val_lbl_folder]
|
||||
# Copy top 20 images and labels
|
||||
copy_files(image_folder, LATEST_DETECT_IMAGES, 20)
|
||||
copy_files(label_folder, LATEST_DETECT_LABELS, 20)
|
||||
|
||||
# Delete all files in source folders
|
||||
for path in [image_folder, label_folder]:
|
||||
if not os.path.exists(path):
|
||||
log_message(f"Path does not exist: {path}")
|
||||
continue
|
||||
|
||||
for path in paths:
|
||||
for file in glob.glob(os.path.join(path, "*")):
|
||||
if os.path.isfile(file):
|
||||
os.remove(file)
|
||||
if os.path.isfile(file):
|
||||
try:
|
||||
os.remove(file)
|
||||
# log_message(f"Deleted: {file}")
|
||||
except Exception as e:
|
||||
log_message(f"Error deleting {file}: {e}")
|
||||
|
||||
|
||||
log_message(f"Delete source image success \n")
|
||||
log_message("Delete source image success \n")
|
||||
log_message("END " + ("=" * 20) + "\n")
|
||||
|
||||
|
||||
|
|
@ -185,7 +224,15 @@ def move_files(file_list, dest_img_folder, dest_lbl_folder):
|
|||
return copied_images, copied_labels
|
||||
|
||||
|
||||
def train_yolo_model(pretrained_model: str, dataset_folder: str,project_name: str, name: str, epochs: int = 50, batch_size: int = 16, img_size: int = 640, lr: float = 0.001):
|
||||
def train_yolo_model(
|
||||
pretrained_model: str,
|
||||
dataset_folder: str,
|
||||
project_name: str,
|
||||
name: str,
|
||||
epochs: int = 50,
|
||||
batch_size: int = 16,
|
||||
img_size: int = 640,
|
||||
lr: float = 0.001):
|
||||
dataset_yaml = os.path.join(dataset_folder, "data.yaml")
|
||||
|
||||
if not os.path.exists(dataset_yaml):
|
||||
|
|
@ -207,20 +254,17 @@ def train_yolo_model(pretrained_model: str, dataset_folder: str,project_name: st
|
|||
optimizer="AdamW",
|
||||
lr0=lr,
|
||||
weight_decay=0.0005,
|
||||
patience=10,
|
||||
patience=0,
|
||||
verbose=True,
|
||||
project=project_name,
|
||||
name=name
|
||||
name=name
|
||||
)
|
||||
|
||||
with open(LOG_FILE, "a") as log:
|
||||
log.write(f"\n[{datetime.datetime.now()}] Train completed\n")
|
||||
|
||||
call_reload_model_api()
|
||||
clear_images_source()
|
||||
|
||||
|
||||
|
||||
clear_images_source()
|
||||
|
||||
# 🚀 Copy files
|
||||
train_copied_imgs, train_copied_lbls = move_files(train_files, train_img_folder, train_lbl_folder)
|
||||
|
|
@ -248,6 +292,9 @@ else:
|
|||
train_yolo_model(pretrained_model = get_latest_model(
|
||||
trained_model_folder=TRAINED_MODEL_FOLDER,
|
||||
default_model=PRETRAINED_MODEL),
|
||||
dataset_folder = dataset_folder, epochs = 50, name=today_str, project_name=model_folder_name
|
||||
dataset_folder = dataset_folder,
|
||||
epochs = 50,
|
||||
name=today_str,
|
||||
project_name=model_folder_name,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue