Update training
This commit is contained in:
parent
db1f306145
commit
c1214662a3
|
|
@ -30,4 +30,5 @@ yarn-error.log*
|
||||||
src/server/venv/*
|
src/server/venv/*
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
||||||
my_database.db
|
my_database.db
|
||||||
|
latest_detect
|
||||||
|
|
@ -111,7 +111,7 @@ def detect_image():
|
||||||
file.save(image_path)
|
file.save(image_path)
|
||||||
|
|
||||||
# Run with model
|
# Run with model
|
||||||
results = model(image_path, conf=0.6)
|
results = model(image_path, conf=0.5)
|
||||||
points = []
|
points = []
|
||||||
|
|
||||||
# Result predict
|
# Result predict
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ import sqlite3
|
||||||
DATA_SPLIT_FOLDER = "model_datasets"
|
DATA_SPLIT_FOLDER = "model_datasets"
|
||||||
IMAGE_EXTENSION = ".png"
|
IMAGE_EXTENSION = ".png"
|
||||||
LOG_FILE = "training_logs.log"
|
LOG_FILE = "training_logs.log"
|
||||||
|
LATEST_DETECT_IMAGES = "latest_detect/images"
|
||||||
|
LATEST_DETECT_LABELS = "latest_detect/labels"
|
||||||
|
|
||||||
def get_label_names():
|
def get_label_names():
|
||||||
conn = sqlite3.connect("my_database.db")
|
conn = sqlite3.connect("my_database.db")
|
||||||
|
|
@ -84,6 +86,17 @@ split_idx = int(len(image_files) * train_ratio)
|
||||||
train_files = image_files[:split_idx]
|
train_files = image_files[:split_idx]
|
||||||
val_files = image_files[split_idx:]
|
val_files = image_files[split_idx:]
|
||||||
|
|
||||||
|
latest_img_folder = [f for f in os.listdir(LATEST_DETECT_IMAGES) if f.endswith(IMAGE_EXTENSION)]
|
||||||
|
train_files += latest_img_folder
|
||||||
|
# Copy latest detect images to train folder
|
||||||
|
for img_file in latest_img_folder:
|
||||||
|
img_path = os.path.join(LATEST_DETECT_IMAGES, img_file)
|
||||||
|
label_file = os.path.splitext(img_file)[0] + ".txt"
|
||||||
|
label_path = os.path.join(LATEST_DETECT_LABELS, label_file)
|
||||||
|
|
||||||
|
if os.path.exists(img_path) and os.path.exists(label_path) and not os.path.exists(os.path.join(image_folder, img_file)):
|
||||||
|
shutil.copy(img_path, os.path.join(image_folder, img_file))
|
||||||
|
shutil.copy(label_path, os.path.join(label_folder, label_file))
|
||||||
|
|
||||||
|
|
||||||
def log_message(message: str):
|
def log_message(message: str):
|
||||||
|
|
@ -146,17 +159,43 @@ def call_reload_model_api(base_url="http://localhost:5000"):
|
||||||
log_message(f" Error calling reload model API: {e} \n")
|
log_message(f" Error calling reload model API: {e} \n")
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
def copy_files(src_folder, dest_folder, num_files=20):
|
||||||
|
if not os.path.exists(dest_folder):
|
||||||
|
os.makedirs(dest_folder)
|
||||||
|
|
||||||
|
files = sorted(
|
||||||
|
glob.glob(os.path.join(src_folder, "*")),
|
||||||
|
key=os.path.getmtime, # Sort by modification time (latest first)
|
||||||
|
reverse=True
|
||||||
|
)[:num_files] # Get top N files
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
try:
|
||||||
|
shutil.copy(file, os.path.join(dest_folder, os.path.basename(file)))
|
||||||
|
# log_message(f"Copied: {file} -> {dest_folder}")
|
||||||
|
except Exception as e:
|
||||||
|
log_message(f"Error copying {file}: {e}")
|
||||||
|
|
||||||
def clear_images_source():
|
def clear_images_source():
|
||||||
paths = [image_folder, label_folder]
|
# Copy top 20 images and labels
|
||||||
# paths = [image_folder, label_folder, train_img_folder, train_lbl_folder, val_img_folder, val_lbl_folder]
|
copy_files(image_folder, LATEST_DETECT_IMAGES, 20)
|
||||||
|
copy_files(label_folder, LATEST_DETECT_LABELS, 20)
|
||||||
|
|
||||||
|
# Delete all files in source folders
|
||||||
|
for path in [image_folder, label_folder]:
|
||||||
|
if not os.path.exists(path):
|
||||||
|
log_message(f"Path does not exist: {path}")
|
||||||
|
continue
|
||||||
|
|
||||||
for path in paths:
|
|
||||||
for file in glob.glob(os.path.join(path, "*")):
|
for file in glob.glob(os.path.join(path, "*")):
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
os.remove(file)
|
try:
|
||||||
|
os.remove(file)
|
||||||
|
# log_message(f"Deleted: {file}")
|
||||||
|
except Exception as e:
|
||||||
|
log_message(f"Error deleting {file}: {e}")
|
||||||
|
|
||||||
|
log_message("Delete source image success \n")
|
||||||
log_message(f"Delete source image success \n")
|
|
||||||
log_message("END " + ("=" * 20) + "\n")
|
log_message("END " + ("=" * 20) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -185,7 +224,15 @@ def move_files(file_list, dest_img_folder, dest_lbl_folder):
|
||||||
return copied_images, copied_labels
|
return copied_images, copied_labels
|
||||||
|
|
||||||
|
|
||||||
def train_yolo_model(pretrained_model: str, dataset_folder: str,project_name: str, name: str, epochs: int = 50, batch_size: int = 16, img_size: int = 640, lr: float = 0.001):
|
def train_yolo_model(
|
||||||
|
pretrained_model: str,
|
||||||
|
dataset_folder: str,
|
||||||
|
project_name: str,
|
||||||
|
name: str,
|
||||||
|
epochs: int = 50,
|
||||||
|
batch_size: int = 16,
|
||||||
|
img_size: int = 640,
|
||||||
|
lr: float = 0.001):
|
||||||
dataset_yaml = os.path.join(dataset_folder, "data.yaml")
|
dataset_yaml = os.path.join(dataset_folder, "data.yaml")
|
||||||
|
|
||||||
if not os.path.exists(dataset_yaml):
|
if not os.path.exists(dataset_yaml):
|
||||||
|
|
@ -207,20 +254,17 @@ def train_yolo_model(pretrained_model: str, dataset_folder: str,project_name: st
|
||||||
optimizer="AdamW",
|
optimizer="AdamW",
|
||||||
lr0=lr,
|
lr0=lr,
|
||||||
weight_decay=0.0005,
|
weight_decay=0.0005,
|
||||||
patience=10,
|
patience=0,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
project=project_name,
|
project=project_name,
|
||||||
name=name
|
name=name
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(LOG_FILE, "a") as log:
|
with open(LOG_FILE, "a") as log:
|
||||||
log.write(f"\n[{datetime.datetime.now()}] Train completed\n")
|
log.write(f"\n[{datetime.datetime.now()}] Train completed\n")
|
||||||
|
|
||||||
call_reload_model_api()
|
call_reload_model_api()
|
||||||
clear_images_source()
|
clear_images_source()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 🚀 Copy files
|
# 🚀 Copy files
|
||||||
train_copied_imgs, train_copied_lbls = move_files(train_files, train_img_folder, train_lbl_folder)
|
train_copied_imgs, train_copied_lbls = move_files(train_files, train_img_folder, train_lbl_folder)
|
||||||
|
|
@ -248,6 +292,9 @@ else:
|
||||||
train_yolo_model(pretrained_model = get_latest_model(
|
train_yolo_model(pretrained_model = get_latest_model(
|
||||||
trained_model_folder=TRAINED_MODEL_FOLDER,
|
trained_model_folder=TRAINED_MODEL_FOLDER,
|
||||||
default_model=PRETRAINED_MODEL),
|
default_model=PRETRAINED_MODEL),
|
||||||
dataset_folder = dataset_folder, epochs = 50, name=today_str, project_name=model_folder_name
|
dataset_folder = dataset_folder,
|
||||||
|
epochs = 50,
|
||||||
|
name=today_str,
|
||||||
|
project_name=model_folder_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue