"""
train_model.py
──────────────
Fine-tune yolo26n.pt on the extracted person detection dataset.

Optimized for NVIDIA RTX A5000 (16GB VRAM):
- Mixed precision (AMP) enabled
- Batch size 16, image size 640
- Early stopping with patience 15
- Full YOLO augmentation pipeline
"""

import json
import logging
import os
import shutil
import sys
from datetime import datetime
from pathlib import Path

from ultralytics import YOLO

import pipeline_config as cfg

# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
os.makedirs(cfg.LOG_DIR, exist_ok=True)

# Timestamped log file so successive runs never overwrite each other.
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(cfg.LOG_DIR, f"train_{_run_stamp}.log")

# Mirror every record to both the log file and stdout.
_log_handlers = [
    logging.FileHandler(log_file),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

def verify_dataset() -> str:
    """Verify the dataset layout on disk and return the path to dataset.yaml.

    Exits the process (status 1) when the YAML config, either image split
    directory, or all training images are missing. An empty test split is
    tolerated with a warning so training can proceed without validation.

    Returns:
        Absolute path to ``dataset.yaml`` inside ``cfg.DATASET_DIR``.
    """
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
    if not os.path.exists(yaml_path):
        logger.error(f"dataset.yaml not found at {yaml_path}")
        logger.error("Run extract_dataset.py and split_dataset.py first!")
        sys.exit(1)

    # Both image split directories must exist before training can start.
    train_images = os.path.join(cfg.DATASET_DIR, "images", "train")
    test_images = os.path.join(cfg.DATASET_DIR, "images", "test")

    if not os.path.exists(train_images):
        logger.error(f"Train images directory not found: {train_images}")
        sys.exit(1)
    if not os.path.exists(test_images):
        logger.error(f"Test images directory not found: {test_images}")
        sys.exit(1)

    # Count .jpg frames per split; sum() over a generator avoids building
    # a throwaway list just to take its length.
    n_train = sum(1 for f in os.listdir(train_images) if f.endswith(".jpg"))
    n_test = sum(1 for f in os.listdir(test_images) if f.endswith(".jpg"))

    logger.info("Dataset verified:")
    logger.info(f"  Train images: {n_train}")
    logger.info(f"  Test images: {n_test}")

    if n_train == 0:
        logger.error("No training images found!")
        sys.exit(1)
    if n_test == 0:
        logger.warning("No test images found — training will proceed without validation.")

    return yaml_path

def _export_best_weights(best_weights: str, last_weights: str) -> str:
    """Publish the best checkpoint and return the weights path to use.

    Copies ``best.pt`` to the project root (``person_detector_best.pt``)
    for easy access. If ``best.pt`` is missing, falls back to ``last.pt``
    — and logs an error when that is missing too, so the caller is never
    silently handed a path that does not exist.
    """
    if os.path.exists(best_weights):
        logger.info(f"Best weights: {best_weights}")
        # Copy best weights to project root for easy access.
        output_model = os.path.join(cfg.BASE_DIR, "person_detector_best.pt")
        shutil.copy2(best_weights, output_model)
        logger.info(f"Copied best model to: {output_model}")
        return best_weights

    logger.warning(f"Best weights not found at expected path: {best_weights}")
    if not os.path.exists(last_weights):
        # Neither checkpoint exists — surface this loudly instead of
        # returning a bogus path without comment.
        logger.error(f"Fallback weights also missing: {last_weights}")
    return last_weights


def _save_training_summary(start_time, end_time, best_weights: str, train_args: dict) -> None:
    """Write a JSON summary of the run next to the training outputs."""
    summary = {
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
        "duration_seconds": (end_time - start_time).total_seconds(),
        "best_weights": best_weights,
        # Stringify values so the dict is always JSON-serializable.
        "train_args": {k: str(v) for k, v in train_args.items()},
    }
    summary_path = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "training_summary.json")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)


def train_model() -> str:
    """
    Main entry point: train yolo26n.pt on the person detection dataset.
    Returns path to best weights.

    Exits (via verify_dataset) when the dataset is missing or empty.
    Side effects: trains on GPU 0, writes run artifacts under
    ``cfg.TRAIN_PROJECT/cfg.TRAIN_NAME``, copies the best checkpoint to
    the project root, and saves ``training_summary.json``.
    """
    logger.info("=" * 60)
    logger.info("MODEL TRAINING STARTED")
    logger.info("=" * 60)

    # Verify dataset before touching the GPU.
    yaml_path = verify_dataset()
    logger.info(f"Dataset config: {yaml_path}")

    # Load base model
    logger.info(f"Loading base model: {cfg.TRAIN_MODEL}")
    model = YOLO(cfg.TRAIN_MODEL)
    logger.info("Base model loaded ✓")

    # Training configuration
    train_args = {
        "data": yaml_path,
        "epochs": cfg.TRAIN_EPOCHS,
        "batch": cfg.TRAIN_BATCH,
        "imgsz": cfg.TRAIN_IMGSZ,
        "device": 0,  # GPU
        "workers": cfg.TRAIN_WORKERS,
        "patience": cfg.EARLY_STOP_PATIENCE,
        "project": cfg.TRAIN_PROJECT,
        "name": cfg.TRAIN_NAME,
        "exist_ok": True,  # Overwrite previous run
        "pretrained": True,
        "save": True,
        "save_period": 10,  # Save checkpoint every 10 epochs
        "val": True,
        "plots": True,  # Generate training plots
        "verbose": True,
        # Augmentation (YOLO defaults are good, but explicit for clarity)
        "hsv_h": 0.015,
        "hsv_s": 0.7,
        "hsv_v": 0.4,
        "degrees": 0.0,
        "translate": 0.1,
        "scale": 0.5,
        "shear": 0.0,
        "flipud": 0.0,  # No vertical flip (people don't appear upside down)
        "fliplr": 0.5,  # Horizontal flip is fine
        "mosaic": 1.0,
        "mixup": 0.1,
    }

    logger.info("Training configuration:")
    for k, v in train_args.items():
        logger.info(f"  {k}: {v}")

    # Start training
    logger.info("\n" + "─" * 40)
    logger.info("TRAINING IN PROGRESS...")
    logger.info("─" * 40)

    start_time = datetime.now()
    # Return value intentionally discarded: all artifacts (weights, plots,
    # metrics) are written to disk by Ultralytics under project/name.
    model.train(**train_args)
    end_time = datetime.now()

    # Results
    logger.info("\n" + "=" * 60)
    logger.info("TRAINING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Duration: {end_time - start_time}")

    # Find best weights and publish them for downstream steps.
    weights_dir = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "weights")
    best_weights = _export_best_weights(
        os.path.join(weights_dir, "best.pt"),
        os.path.join(weights_dir, "last.pt"),
    )

    # Save training summary
    _save_training_summary(start_time, end_time, best_weights, train_args)

    return best_weights

if __name__ == "__main__":
    # Allow running this module directly as a training script.
    train_model()