utopia-surveillance-tool/train_model.py

179 lines
5.8 KiB
Python

"""
train_model.py
──────────────
Fine-tune yolo26n.pt on the extracted person detection dataset.
Optimized for NVIDIA RTX A5000 (16GB VRAM):
- Mixed precision (AMP) enabled (Ultralytics default; `amp` is not overridden in this script)
- Batch size 16, image size 640
- Early stopping with patience 15
- Full YOLO augmentation pipeline
"""
import os
import sys
import json
import logging
from pathlib import Path
from datetime import datetime
from ultralytics import YOLO
import pipeline_config as cfg
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# One log file per run, named after the start timestamp, under cfg.LOG_DIR.
# Records at INFO and above go to both that file and stdout.
log_file = os.path.join(
    cfg.LOG_DIR,
    f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log",
)
# The directory has to exist before FileHandler opens the log file.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
logging.basicConfig(
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def verify_dataset() -> str:
    """Validate the on-disk dataset layout and return the path to dataset.yaml.

    Checks that ``dataset.yaml`` exists under ``cfg.DATASET_DIR`` and that the
    ``images/train`` and ``images/test`` directories exist, then counts the
    image files in each directory and logs the totals.

    Returns:
        Path to ``dataset.yaml`` as a string.

    Raises:
        SystemExit: With exit code 1 if ``dataset.yaml`` or either image
            directory is missing, or if no training images are found.
    """
    # Common image extensions, matched case-insensitively so that files such
    # as "IMG_001.JPG" or "frame.png" are not silently counted as zero.
    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

    def count_images(directory: str) -> int:
        """Return the number of image files directly inside *directory*."""
        return sum(
            1 for name in os.listdir(directory)
            if name.lower().endswith(image_exts)
        )

    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
    if not os.path.isfile(yaml_path):
        logger.error(f"dataset.yaml not found at {yaml_path}")
        logger.error("Run extract_dataset.py and split_dataset.py first!")
        sys.exit(1)

    # Check directories exist. isdir (not exists) so that a stray regular file
    # with the same name is reported here instead of crashing os.listdir below.
    train_images = os.path.join(cfg.DATASET_DIR, "images", "train")
    test_images = os.path.join(cfg.DATASET_DIR, "images", "test")
    if not os.path.isdir(train_images):
        logger.error(f"Train images directory not found: {train_images}")
        sys.exit(1)
    if not os.path.isdir(test_images):
        logger.error(f"Test images directory not found: {test_images}")
        sys.exit(1)

    n_train = count_images(train_images)
    n_test = count_images(test_images)
    logger.info("Dataset verified:")
    logger.info(f" Train images: {n_train}")
    logger.info(f" Test images: {n_test}")

    if n_train == 0:
        logger.error("No training images found!")
        sys.exit(1)
    if n_test == 0:
        # NOTE(review): train_model() passes val=True unconditionally, so
        # whether training really proceeds with an empty test split depends
        # on the trainer - confirm before relying on this message.
        logger.warning("No test images found — training will proceed without validation.")
    return yaml_path
def train_model() -> str:
    """Fine-tune the base YOLO model on the person detection dataset.

    Steps, in order:
      1. Verify the dataset via ``verify_dataset()``.
      2. Load the base model named by ``cfg.TRAIN_MODEL``.
      3. Train with the settings from ``pipeline_config`` plus the explicit
         augmentation values below.
      4. If ``best.pt`` exists in the run directory, copy it to
         ``cfg.BASE_DIR/person_detector_best.pt``.
      5. Write ``training_summary.json`` into the run directory.

    Returns:
        Path to ``best.pt`` if it exists after training, otherwise the
        expected path of ``last.pt`` (which may itself be missing; a warning
        is logged in that case).
    """
    # Function-scope import: shutil is not in the file-level import block.
    import shutil

    logger.info("=" * 60)
    logger.info("MODEL TRAINING STARTED")
    logger.info("=" * 60)

    # Verify dataset
    yaml_path = verify_dataset()
    logger.info(f"Dataset config: {yaml_path}")

    # Load base model
    logger.info(f"Loading base model: {cfg.TRAIN_MODEL}")
    model = YOLO(cfg.TRAIN_MODEL)
    logger.info("Base model loaded ✓")

    # Training configuration
    train_args = {
        "data": yaml_path,
        "epochs": cfg.TRAIN_EPOCHS,
        "batch": cfg.TRAIN_BATCH,
        "imgsz": cfg.TRAIN_IMGSZ,
        "device": 0,  # GPU
        "workers": cfg.TRAIN_WORKERS,
        "patience": cfg.EARLY_STOP_PATIENCE,
        "project": cfg.TRAIN_PROJECT,
        "name": cfg.TRAIN_NAME,
        "exist_ok": True,  # Overwrite previous run
        "pretrained": True,
        "save": True,
        "save_period": 10,  # Save checkpoint every 10 epochs
        "val": True,
        "plots": True,  # Generate training plots
        "verbose": True,
        # Augmentation (YOLO defaults are good, but explicit for clarity)
        "hsv_h": 0.015,
        "hsv_s": 0.7,
        "hsv_v": 0.4,
        "degrees": 0.0,
        "translate": 0.1,
        "scale": 0.5,
        "shear": 0.0,
        "flipud": 0.0,  # No vertical flip (people don't appear upside down)
        "fliplr": 0.5,  # Horizontal flip is fine
        "mosaic": 1.0,
        "mixup": 0.1,
    }
    logger.info("Training configuration:")
    for key, value in train_args.items():
        logger.info(f" {key}: {value}")

    # Start training. The separator previously multiplied an empty string
    # (its character had been lost), so the rule lines logged nothing.
    separator = "-" * 40
    logger.info("\n" + separator)
    logger.info("TRAINING IN PROGRESS...")
    logger.info(separator)
    start_time = datetime.now()
    model.train(**train_args)
    end_time = datetime.now()
    duration = end_time - start_time

    # Results
    logger.info("\n" + "=" * 60)
    logger.info("TRAINING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Duration: {duration}")

    # Find best weights. The run directory is derived from the same
    # project/name values that were passed to the trainer above.
    run_dir = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME)
    best_weights = os.path.join(run_dir, "weights", "best.pt")
    last_weights = os.path.join(run_dir, "weights", "last.pt")
    if os.path.exists(best_weights):
        logger.info(f"Best weights: {best_weights}")
        # Copy best weights to project root for easy access
        output_model = os.path.join(cfg.BASE_DIR, "person_detector_best.pt")
        shutil.copy2(best_weights, output_model)
        logger.info(f"Copied best model to: {output_model}")
    else:
        logger.warning(f"Best weights not found at expected path: {best_weights}")
        best_weights = last_weights
        if not os.path.exists(last_weights):
            # Make the fallback failure visible instead of returning a
            # non-existent path without any trace in the log.
            logger.warning(f"Fallback weights not found either: {last_weights}")

    # Save training summary
    summary = {
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
        "duration_seconds": duration.total_seconds(),
        "best_weights": best_weights,
        # Values are stringified so every entry is JSON-serializable.
        "train_args": {key: str(value) for key, value in train_args.items()},
    }
    summary_path = os.path.join(run_dir, "training_summary.json")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    return best_weights
# Script entry point: run the full training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_model()