feat: implement adaptive video frame extraction pipeline with YOLO-based auto-labeling and checkpointing
parent
a1da64e017
commit
a7184a5773
|
|
@ -2,3 +2,5 @@
|
|||
*.pt
|
||||
alerts/
|
||||
__pycache__
|
||||
dataset/
|
||||
video_data/
|
||||
|
|
@ -0,0 +1,504 @@
|
|||
"""
|
||||
extract_dataset.py
|
||||
──────────────────
|
||||
Extracts person-detection frames from surveillance videos using adaptive
|
||||
frame sampling and yolo26x.pt for auto-labeling.
|
||||
|
||||
Features:
|
||||
- Adaptive FPS: baseline 1 FPS, high 3 FPS (person), low 0.5 FPS (idle)
|
||||
- GPU-accelerated YOLO inference in batches
|
||||
- Per-video checkpointing for crash recovery
|
||||
- 50 GB dataset size cap
|
||||
- Organizes output in YOLO detection format
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import json
|
||||
import time
|
||||
import glob
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
import pipeline_config as cfg
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Create output directories up front: the FileHandler below needs LOG_DIR,
# and checkpointing needs CHECKPOINT_DIR before the first video is processed.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
os.makedirs(cfg.CHECKPOINT_DIR, exist_ok=True)

# One timestamped log file per extraction run, mirrored to stdout.
log_file = os.path.join(cfg.LOG_DIR, f"extract_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Helpers
|
||||
# ──────────────────────────────────────────────
|
||||
def get_dataset_size_gb(dataset_dir: str) -> float:
    """Return the total on-disk size of *dataset_dir* in GB (GiB).

    Walks the tree recursively. Files that vanish between listing and
    stat — extraction writes and the split script move files concurrently
    with this check — are skipped instead of crashing the size-cap check.
    A nonexistent directory yields 0.0 (os.walk produces nothing).
    """
    total = 0
    for dirpath, _, filenames in os.walk(dataset_dir):
        for name in filenames:
            try:
                total += os.path.getsize(os.path.join(dirpath, name))
            except OSError:
                # File removed/renamed mid-walk — ignore for the size cap.
                continue
    return total / (1024 ** 3)
|
||||
|
||||
|
||||
def get_checkpoint_path(video_path: str) -> str:
    """Build the path of the per-video checkpoint JSON inside CHECKPOINT_DIR."""
    stem = Path(video_path).stem
    return os.path.join(cfg.CHECKPOINT_DIR, stem + ".json")
|
||||
|
||||
|
||||
def load_checkpoint(video_path: str) -> dict | None:
    """Load the checkpoint dict for *video_path*, or None if absent or corrupt.

    A crash during a previous (non-atomic) checkpoint write can leave
    truncated JSON on disk; treating that the same as "no checkpoint" lets
    the video restart cleanly instead of crashing every resume attempt.
    """
    cp_path = get_checkpoint_path(video_path)
    try:
        with open(cp_path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        # No checkpoint yet — normal for a fresh video.
        return None
    except (OSError, json.JSONDecodeError):
        logger.warning(f"Ignoring unreadable/corrupt checkpoint: {cp_path}")
        return None
|
||||
|
||||
|
||||
def save_checkpoint(video_path: str, data: dict):
    """Atomically persist the checkpoint dict for *video_path*.

    Writes to a sibling temp file and os.replace()s it into place, so a
    crash mid-write (the whole point of checkpointing is crash recovery)
    can never leave a truncated JSON checkpoint behind.
    """
    cp_path = get_checkpoint_path(video_path)
    tmp_path = cp_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, cp_path)  # atomic on POSIX and Windows
|
||||
|
||||
|
||||
def mark_video_done(video_path: str, stats: dict):
    """Mark a video as fully processed.

    Mutates *stats* in place (process_video returns this same dict, and the
    skip-if-done check reads the "done" key from it) and persists it as the
    video's checkpoint.
    """
    stats["done"] = True
    save_checkpoint(video_path, stats)
|
||||
|
||||
|
||||
def discover_videos() -> dict[str, list[str]]:
    """
    Scan cfg.VIDEO_DIR for camera subdirectories and their video files.

    Each immediate subdirectory is treated as one camera. Matching is done
    per extension in both lower- and upper-case. Returns
    {camera_name: sorted list of video paths}; exits the process if the
    video root is missing or no videos are found at all.
    """
    root = Path(cfg.VIDEO_DIR)
    if not root.exists():
        logger.error(f"Video directory not found: {cfg.VIDEO_DIR}")
        sys.exit(1)

    cameras: dict[str, list[str]] = {}
    for cam_dir in sorted(root.iterdir()):
        if not cam_dir.is_dir():
            continue
        # Collect into a set: on case-insensitive filesystems the lower-
        # and upper-case globs can match the same file twice.
        found: set[str] = set()
        for ext in cfg.VIDEO_EXTENSIONS:
            found.update(glob.glob(str(cam_dir / f"*{ext}")))
            found.update(glob.glob(str(cam_dir / f"*{ext.upper()}")))
        if found:
            cameras[cam_dir.name] = sorted(found)
            logger.info(f"Camera '{cam_dir.name}': {len(found)} videos")

    if not cameras:
        logger.error("No video files found in any camera directory!")
        sys.exit(1)

    return cameras
|
||||
|
||||
|
||||
def sanitize_camera_name(cam_name: str) -> str:
    """Create a filesystem-safe, lowercase camera identifier.

    Spaces and hyphens become underscores, runs of underscores collapse to
    a single one, and leading/trailing underscores are stripped.

    Fix: the original applied .replace("__", "_") exactly once, which only
    halves a run of underscores (e.g. "a - b" -> "a___b" -> "a__b"), so
    names with adjacent separators were not fully normalized. Loop until a
    fixed point instead.
    """
    safe = cam_name.replace(" ", "_").replace("-", "_")
    while "__" in safe:
        safe = safe.replace("__", "_")
    return safe.strip("_").lower()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Adaptive Sampler State Machine
|
||||
# ──────────────────────────────────────────────
|
||||
class AdaptiveSampler:
    """
    State machine that picks the effective sampling rate for one video.

    States:
      - NORMAL: sample at cfg.BASE_FPS (baseline, 1 fps)
      - HIGH:   sample at cfg.HIGH_FPS — a person was seen recently
      - LOW:    sample at cfg.LOW_FPS — long idle stretch

    Transitions: any detection jumps straight to HIGH; HIGH decays to
    NORMAL after cfg.HIGH_FPS_DURATION person-free seconds; NORMAL decays
    to LOW after cfg.LOW_FPS_THRESHOLD idle seconds; LOW only leaves on a
    new detection.
    """

    def __init__(self, video_fps: float):
        self.video_fps = video_fps
        self.state = "NORMAL"
        self.current_sample_fps = cfg.BASE_FPS
        # Video timestamp (seconds) of the most recent person detection.
        self.last_person_time = 0.0
        # Seconds elapsed since the last person detection.
        self.no_person_streak = 0.0

    def _enter(self, state: str, fps: float):
        """Switch to *state*, sampling at *fps*."""
        self.state = state
        self.current_sample_fps = fps

    def update(self, timestamp: float, person_detected: bool):
        """Advance the state machine given the detection result at *timestamp*."""
        if person_detected:
            self.last_person_time = timestamp
            self.no_person_streak = 0.0
            self._enter("HIGH", cfg.HIGH_FPS)
            return

        self.no_person_streak = timestamp - self.last_person_time
        if self.state == "HIGH" and self.no_person_streak > cfg.HIGH_FPS_DURATION:
            # High-rate window after the last detection has expired.
            self._enter("NORMAL", cfg.BASE_FPS)
        elif self.state == "NORMAL" and self.no_person_streak > cfg.LOW_FPS_THRESHOLD:
            self._enter("LOW", cfg.LOW_FPS)
        # LOW persists until the next detection.

    def get_frame_interval(self) -> int:
        """Number of raw video frames to skip between samples (always >= 1)."""
        return max(1, int(self.video_fps / self.current_sample_fps))
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Core Extraction
|
||||
# ──────────────────────────────────────────────
|
||||
def process_video(
    model: YOLO,
    video_path: str,
    camera_name: str,
    output_images_dir: str,
    output_labels_dir: str,
    global_stats: dict,
) -> dict:
    """
    Process a single video file: extract frames with adaptive sampling,
    detect persons with YOLO, save frames and labels.

    Resumable: progress (position + running counts) is checkpointed every
    ~30 s and on completion; a video marked done is skipped outright.
    Honors the global dataset size cap via global_stats["size_cap_reached"].

    Returns the per-video stats dict.
    """
    cam_safe = sanitize_camera_name(camera_name)
    video_name = Path(video_path).stem
    # BUGFIX: frame IDs embed the video stem as well as the camera. The
    # per-video counter restarts at 0 for every video, so the old
    # "{cam}_{counter:06d}" scheme made later videos of the same camera
    # silently overwrite earlier videos' frames.
    video_safe = sanitize_camera_name(video_name)

    # Check if already done
    checkpoint = load_checkpoint(video_path)
    if checkpoint and checkpoint.get("done"):
        logger.info(f" ⏭ Skipping (already done): {video_name}")
        return checkpoint

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f" ✗ Cannot open video: {video_path}")
        return {"error": "cannot_open", "video": video_name}

    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps <= 0:
        # BUGFIX: broken FPS metadata previously caused a ZeroDivisionError
        # at the timestamp computation inside the loop. Bail out early.
        cap.release()
        logger.error(f" ✗ Invalid FPS metadata ({video_fps}): {video_path}")
        return {"error": "invalid_fps", "video": video_name}
    duration_sec = total_frames / video_fps

    logger.info(f" ▶ Processing: {video_name}")
    logger.info(f" FPS: {video_fps:.1f}, Frames: {total_frames}, Duration: {duration_sec:.0f}s")

    # Resume from checkpoint — restore the position AND the running tallies.
    # BUGFIX: the counters used to be reset to 0 on resume, so a resumed
    # run lost all previously recorded extraction/person counts.
    start_frame = 0
    frame_counter = 0
    prev_extracted = prev_with = prev_without = 0
    if checkpoint:
        start_frame = checkpoint.get("last_frame", 0)
        frame_counter = checkpoint.get("frame_counter", 0)
        prev_extracted = checkpoint.get("frames_extracted", 0)
        prev_with = checkpoint.get("frames_with_person", 0)
        prev_without = checkpoint.get("frames_without_person", 0)
        logger.info(f" Resuming from frame {start_frame}")

    sampler = AdaptiveSampler(video_fps)
    stats = {
        "video": video_name,
        "camera": camera_name,
        "total_frames": total_frames,
        "frames_extracted": prev_extracted,
        "frames_with_person": prev_with,
        "frames_without_person": prev_without,
        "last_frame": start_frame,
        "frame_counter": frame_counter,
        "done": False,
    }

    # Seek to start position if resuming
    if start_frame > 0:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

    current_frame_idx = start_frame
    batch_frames = []     # [(frame_idx, frame_array, timestamp), ...]
    batch_save_info = []  # [(output_img_path, output_lbl_path), ...]

    last_checkpoint_time = time.time()

    while True:
        # Check the dataset size cap every 500 extractions — a full
        # directory walk is too expensive to run per frame.
        if stats["frames_extracted"] % 500 == 0 and stats["frames_extracted"] > 0:
            current_size = get_dataset_size_gb(cfg.DATASET_DIR)
            if current_size >= cfg.MAX_DATASET_SIZE_GB:
                logger.warning(f" ⚠ Dataset size cap reached ({current_size:.1f} GB). Stopping extraction.")
                global_stats["size_cap_reached"] = True
                break

        # Next frame to sample, per the adaptive rate
        interval = sampler.get_frame_interval()
        target_frame = current_frame_idx + interval

        if target_frame >= total_frames:
            break

        # Seek to target frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
        ret, frame = cap.read()
        if not ret:
            break

        current_frame_idx = target_frame
        timestamp = current_frame_idx / video_fps

        # Build a unique output filename: camera + video + running counter
        frame_counter += 1
        frame_id = f"{cam_safe}_{video_safe}_{frame_counter:06d}"
        img_path = os.path.join(output_images_dir, f"{frame_id}.jpg")
        lbl_path = os.path.join(output_labels_dir, f"{frame_id}.txt")

        batch_frames.append((current_frame_idx, frame, timestamp))
        batch_save_info.append((img_path, lbl_path))

        # Process batch when full (return value intentionally unused here —
        # _process_batch updates stats/sampler in place)
        if len(batch_frames) >= cfg.BATCH_SIZE:
            _process_batch(model, batch_frames, batch_save_info, sampler, stats)
            batch_frames.clear()
            batch_save_info.clear()

        # Checkpoint every 30 seconds
        if time.time() - last_checkpoint_time > 30:
            stats["last_frame"] = current_frame_idx
            stats["frame_counter"] = frame_counter
            save_checkpoint(video_path, stats)
            last_checkpoint_time = time.time()

        # Progress logging every 1000 extractions
        if stats["frames_extracted"] % 1000 == 0 and stats["frames_extracted"] > 0:
            pct = (current_frame_idx / total_frames) * 100
            logger.info(
                f" Progress: {pct:.1f}% | Extracted: {stats['frames_extracted']} | "
                f"Persons: {stats['frames_with_person']} | Mode: {sampler.state} | "
                f"FPS: {sampler.current_sample_fps}"
            )

    # Process remaining batch
    if batch_frames:
        _process_batch(model, batch_frames, batch_save_info, sampler, stats)

    cap.release()

    stats["last_frame"] = current_frame_idx
    stats["frame_counter"] = frame_counter
    mark_video_done(video_path, stats)

    logger.info(
        f" ✓ Done: {video_name} | Extracted: {stats['frames_extracted']} | "
        f"With person: {stats['frames_with_person']} | "
        f"Without: {stats['frames_without_person']}"
    )

    return stats
|
||||
|
||||
|
||||
def _process_batch(
    model: YOLO,
    batch_frames: list,
    batch_save_info: list,
    sampler: AdaptiveSampler,
    stats: dict,
) -> int:
    """
    Run YOLO inference on a batch of frames, save images + labels.

    batch_frames:    [(frame_idx, frame_array, timestamp), ...]
    batch_save_info: parallel [(image_path, label_path), ...]

    Frames without persons are still written, with an empty label file —
    they serve as background/negative samples. Updates *sampler* and
    *stats* in place. Returns the number of frames with person detections.
    """
    frames = [item[1] for item in batch_frames]
    timestamps = [item[2] for item in batch_frames]

    # Batched GPU inference restricted to the person class; FP16 for speed.
    results = model.predict(
        source=frames,
        conf=cfg.DETECTION_CONF,
        iou=cfg.DETECTION_IOU,
        classes=[cfg.PERSON_CLASS_ID],
        device=0,
        verbose=False,
        half=True,  # FP16 for speed
    )

    persons_count = 0

    for i, result in enumerate(results):
        img_path, lbl_path = batch_save_info[i]
        frame = frames[i]

        # Person detections for this frame (predict() was already class-
        # filtered; this keeps the invariant explicit).
        boxes = result.boxes
        person_boxes = boxes[boxes.cls == cfg.PERSON_CLASS_ID]
        has_person = len(person_boxes) > 0

        # Feed the result back into the adaptive sampling state machine.
        sampler.update(timestamps[i], has_person)

        # Save frame as JPEG
        cv2.imwrite(img_path, frame, [cv2.IMWRITE_JPEG_QUALITY, cfg.JPEG_QUALITY])

        # Save YOLO-format labels: "class x_center y_center w h", normalized.
        h, w = frame.shape[:2]
        with open(lbl_path, "w") as f:
            for box in person_boxes:
                xyxy = box.xyxy[0].cpu().numpy()
                x_center = ((xyxy[0] + xyxy[2]) / 2) / w
                y_center = ((xyxy[1] + xyxy[3]) / 2) / h
                box_w = (xyxy[2] - xyxy[0]) / w
                box_h = (xyxy[3] - xyxy[1]) / h
                # Single-class dataset: class id is always 0. (The original
                # also computed box.conf here but never used it — removed.)
                f.write(f"0 {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}\n")

        stats["frames_extracted"] += 1
        if has_person:
            stats["frames_with_person"] += 1
            persons_count += 1
        else:
            stats["frames_without_person"] += 1

    return persons_count
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Main
|
||||
# ──────────────────────────────────────────────
|
||||
def extract_all() -> dict:
    """
    Main entry point: discover videos, extract dataset.
    Returns {camera_name: [video_stats]}.

    Flow: load the detector on GPU, walk every camera/video via
    process_video(), maintain dataset/camera_mapping.json (camera → frame
    filenames) for the split script, and write extraction_stats.json at
    the end. Honors the size cap flag set by process_video().
    """
    logger.info("=" * 60)
    logger.info("DATASET EXTRACTION STARTED")
    logger.info("=" * 60)

    # Load YOLO model on GPU
    logger.info(f"Loading detector model: {cfg.DETECTOR_MODEL}")
    model = YOLO(cfg.DETECTOR_MODEL)
    model.to("cuda")
    logger.info("Model loaded on GPU ✓")

    # Discover cameras and videos
    cameras = discover_videos()
    total_videos = sum(len(v) for v in cameras.values())
    logger.info(f"Found {len(cameras)} cameras, {total_videos} videos total")

    # Create output directories (flat — split happens later)
    images_dir = os.path.join(cfg.DATASET_DIR, "images", "all")
    labels_dir = os.path.join(cfg.DATASET_DIR, "labels", "all")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    # Save camera mapping for split script
    camera_mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")

    # Load existing mapping if resuming
    if os.path.exists(camera_mapping_path):
        with open(camera_mapping_path, "r") as f:
            camera_mapping = json.load(f)
    else:
        camera_mapping = {}

    # size_cap_reached is a cross-video flag set inside process_video.
    global_stats = {"size_cap_reached": False}
    all_stats = {}
    video_num = 0

    for cam_name, video_list in cameras.items():
        cam_safe = sanitize_camera_name(cam_name)
        logger.info(f"\n{'─' * 40}")
        logger.info(f"Camera: {cam_name} ({len(video_list)} videos)")
        logger.info(f"{'─' * 40}")

        cam_stats = []
        for video_path in video_list:
            video_num += 1
            logger.info(f"\n[{video_num}/{total_videos}]")

            if global_stats["size_cap_reached"]:
                logger.warning("Size cap reached — skipping remaining videos.")
                break

            vstats = process_video(
                model, video_path, cam_name,
                images_dir, labels_dir, global_stats
            )
            cam_stats.append(vstats)

            # Track which frames belong to which camera
            if cam_name not in camera_mapping:
                camera_mapping[cam_name] = {"safe_name": cam_safe, "frames": []}

            # Collect frame IDs for this camera.
            # NOTE(review): this re-lists the whole flat images directory
            # after every video (O(total frames) per video) and matches by
            # filename prefix — two cameras whose sanitized names share a
            # prefix (e.g. "cam_1" and "cam_1_ext") would cross-match.
            # Confirm camera names are prefix-free.
            frame_prefix = cam_safe + "_"
            existing_frames = camera_mapping[cam_name].get("frames", [])
            new_frames = [
                f for f in os.listdir(images_dir)
                if f.startswith(frame_prefix) and f.endswith(".jpg")
            ]
            camera_mapping[cam_name]["frames"] = sorted(set(existing_frames + new_frames))

            # Save mapping after each video (crash-safe resume for the split)
            with open(camera_mapping_path, "w") as f:
                json.dump(camera_mapping, f, indent=2)

        all_stats[cam_name] = cam_stats

        if global_stats["size_cap_reached"]:
            break

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("EXTRACTION COMPLETE")
    logger.info("=" * 60)

    total_extracted = 0
    total_persons = 0
    for cam, stats_list in all_stats.items():
        for s in stats_list:
            # .get() with defaults: entries may be error dicts (e.g.
            # {"error": "cannot_open"}) that carry no counters.
            total_extracted += s.get("frames_extracted", 0)
            total_persons += s.get("frames_with_person", 0)

    dataset_size = get_dataset_size_gb(cfg.DATASET_DIR)
    logger.info(f"Total frames extracted: {total_extracted}")
    logger.info(f"Frames with persons: {total_persons}")
    logger.info(f"Dataset size: {dataset_size:.2f} GB")
    logger.info(f"Cameras processed: {len(cameras)}")

    # Save final stats (default=str guards non-JSON-native values)
    stats_path = os.path.join(cfg.DATASET_DIR, "extraction_stats.json")
    with open(stats_path, "w") as f:
        json.dump({
            "total_extracted": total_extracted,
            "total_with_persons": total_persons,
            "dataset_size_gb": round(dataset_size, 2),
            "cameras": {k: len(v) for k, v in cameras.items()},
            "all_stats": all_stats,
        }, f, indent=2, default=str)

    return all_stats
|
||||
|
||||
|
||||
# Allow running extraction standalone (outside run_pipeline.py).
if __name__ == "__main__":
    extract_all()
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
"""
|
||||
Centralized configuration for the person detection dataset pipeline.
|
||||
All tunable parameters are defined here.
|
||||
"""
|
||||
import os
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Paths
|
||||
# ──────────────────────────────────────────────
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
VIDEO_DIR = os.path.join(BASE_DIR, "video_data")
|
||||
DATASET_DIR = os.path.join(BASE_DIR, "dataset")
|
||||
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
|
||||
LOG_DIR = os.path.join(BASE_DIR, "logs")
|
||||
|
||||
# Model paths
|
||||
DETECTOR_MODEL = os.path.join(BASE_DIR, "yolo26x.pt") # Large model for auto-labeling
|
||||
TRAIN_MODEL = "yolo26n.pt" # Nano model to train (auto-downloads)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Dataset Extraction
|
||||
# ──────────────────────────────────────────────
|
||||
MAX_DATASET_SIZE_GB = 50 # Stop extraction if dataset exceeds this
|
||||
JPEG_QUALITY = 85 # JPEG save quality (1-100)
|
||||
DETECTION_CONF = 0.35 # Min confidence for person detection
|
||||
DETECTION_IOU = 0.45 # NMS IoU threshold
|
||||
BATCH_SIZE = 16 # Frames per YOLO inference batch
|
||||
PERSON_CLASS_ID = 0 # YOLO class ID for "person"
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Adaptive Sampling
|
||||
# ──────────────────────────────────────────────
|
||||
BASE_FPS = 1.0 # Default: 1 frame per second
|
||||
HIGH_FPS = 3.0 # When person detected: 3 frames per second
|
||||
LOW_FPS = 0.5 # When idle (no person): 0.5 frames per second
|
||||
HIGH_FPS_DURATION = 5 # Seconds to stay at high FPS after person detected
|
||||
LOW_FPS_THRESHOLD = 10 # Seconds without person before dropping to low FPS
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Train/Test Split (Camera-Level)
|
||||
# ──────────────────────────────────────────────
|
||||
TEST_CAMERAS = 4 # Number of cameras to hold out for testing
|
||||
RANDOM_SEED = 42 # For reproducible camera selection
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Training
|
||||
# ──────────────────────────────────────────────
|
||||
TRAIN_EPOCHS = 100
|
||||
TRAIN_BATCH = 16 # Batch size for training (adjust for VRAM)
|
||||
TRAIN_IMGSZ = 640 # Training image size
|
||||
EARLY_STOP_PATIENCE = 15 # Stop if no improvement for N epochs
|
||||
TRAIN_WORKERS = 8 # DataLoader workers
|
||||
TRAIN_PROJECT = os.path.join(BASE_DIR, "runs", "detect")
|
||||
TRAIN_NAME = "person_detection"
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Video file extensions to process
|
||||
# ──────────────────────────────────────────────
|
||||
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv"}
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
"""
|
||||
run_pipeline.py
|
||||
───────────────
|
||||
Master pipeline script that orchestrates the full workflow:
|
||||
|
||||
Phase 1: Extract frames from videos (adaptive sampling + YOLO detection)
|
||||
Phase 2: Split dataset by camera (train/test)
|
||||
Phase 3: Train yolo26n.pt on the dataset
|
||||
|
||||
Designed to run overnight — resumable from any phase.
|
||||
|
||||
Usage:
|
||||
python run_pipeline.py # Run full pipeline
|
||||
python run_pipeline.py --phase 2 # Start from phase 2 (skip extraction)
|
||||
python run_pipeline.py --phase 3 # Start from phase 3 (skip extract + split)
|
||||
python run_pipeline.py --extract-only # Only extract (no split or train)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
import pipeline_config as cfg
|
||||
|
||||
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Make sure the log directory exists before attaching the file handler.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
# One timestamped log file per pipeline run, mirrored to stdout.
log_file = os.path.join(cfg.LOG_DIR, f"pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def phase1_extract():
    """Phase 1: Extract dataset from videos."""
    banner = "█" * 60
    logger.info("\n" + banner)
    logger.info("█ PHASE 1: DATASET EXTRACTION")
    logger.info(banner + "\n")

    # Imported lazily so the phase's heavy dependencies load only when needed.
    from extract_dataset import extract_all
    return extract_all()
|
||||
|
||||
|
||||
def phase2_split():
    """Phase 2: Split dataset by camera."""
    banner = "█" * 60
    logger.info("\n" + banner)
    logger.info("█ PHASE 2: CAMERA-LEVEL TRAIN/TEST SPLIT")
    logger.info(banner + "\n")

    # Imported lazily so the phase's dependencies load only when needed.
    from split_dataset import split_dataset
    return split_dataset()
|
||||
|
||||
|
||||
def phase3_train():
    """Phase 3: Train model. Returns the best-weights path from training."""
    banner = "█" * 60
    logger.info("\n" + banner)
    logger.info("█ PHASE 3: MODEL TRAINING")
    logger.info(banner + "\n")

    # Imported lazily so the phase's dependencies load only when needed.
    from train_model import train_model
    return train_model()
|
||||
|
||||
|
||||
def run_pipeline(start_phase: int = 1, extract_only: bool = False):
    """Run the full pipeline from the specified starting phase.

    Args:
        start_phase: 1=extract, 2=split, 3=train; earlier phases are skipped.
        extract_only: stop after Phase 1 regardless of later phases.

    Each phase is timed and logged. KeyboardInterrupt or any phase failure
    exits with status 1; the phases checkpoint their own progress, so a
    re-run resumes where it left off.
    """
    pipeline_start = datetime.now()

    logger.info("╔" + "═" * 58 + "╗")
    logger.info("║ PERSON DETECTION PIPELINE ║")
    logger.info("║ " + f"Started: {pipeline_start.strftime('%Y-%m-%d %H:%M:%S')}".ljust(57) + "║")
    logger.info("╚" + "═" * 58 + "╝")
    logger.info("")
    # Plain literal — the original used an f-string with no placeholders.
    logger.info("Configuration:")
    logger.info(f" Video directory: {cfg.VIDEO_DIR}")
    logger.info(f" Dataset output: {cfg.DATASET_DIR}")
    logger.info(f" Detector model: {cfg.DETECTOR_MODEL}")
    logger.info(f" Training model: {cfg.TRAIN_MODEL}")
    logger.info(f" Max dataset size: {cfg.MAX_DATASET_SIZE_GB} GB")
    logger.info(f" Starting phase: {start_phase}")
    logger.info(f" Extract only: {extract_only}")
    logger.info("")

    try:
        # Phase 1: Extract
        if start_phase <= 1:
            p1_start = datetime.now()
            phase1_extract()
            p1_duration = datetime.now() - p1_start
            logger.info(f"\nPhase 1 completed in {p1_duration}")

            if extract_only:
                logger.info("Extract-only mode — stopping after Phase 1.")
                return

        # Phase 2: Split
        if start_phase <= 2:
            p2_start = datetime.now()
            phase2_split()
            p2_duration = datetime.now() - p2_start
            logger.info(f"\nPhase 2 completed in {p2_duration}")

        # Phase 3: Train
        if start_phase <= 3:
            p3_start = datetime.now()
            # Return value (best-weights path) was bound to an unused local
            # before — call for its side effects only.
            phase3_train()
            p3_duration = datetime.now() - p3_start
            logger.info(f"\nPhase 3 completed in {p3_duration}")

        # Final summary
        pipeline_end = datetime.now()
        total_duration = pipeline_end - pipeline_start

        logger.info("\n" + "╔" + "═" * 58 + "╗")
        logger.info("║ PIPELINE COMPLETED SUCCESSFULLY ║")
        logger.info("║ " + f"Duration: {total_duration}".ljust(57) + "║")
        logger.info("╚" + "═" * 58 + "╝")

    except KeyboardInterrupt:
        logger.warning("\n\nPipeline interrupted by user. Progress has been checkpointed.")
        logger.warning("Re-run to resume from where you left off.")
        sys.exit(1)

    except Exception as e:
        logger.error(f"\n\nPipeline failed with error: {e}")
        logger.error(traceback.format_exc())
        logger.error("Progress has been checkpointed. Fix the error and re-run.")
        sys.exit(1)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and launch the pipeline."""
    arg_parser = argparse.ArgumentParser(
        description="Person detection dataset pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python run_pipeline.py Run full pipeline
python run_pipeline.py --phase 2 Start from split phase
python run_pipeline.py --phase 3 Start from training phase
python run_pipeline.py --extract-only Only extract dataset
""",
    )
    # Starting phase: earlier phases are skipped entirely.
    arg_parser.add_argument(
        "--phase",
        type=int,
        default=1,
        choices=[1, 2, 3],
        help="Starting phase (1=extract, 2=split, 3=train)",
    )
    # Shortcut for dataset-building runs without split/train.
    arg_parser.add_argument(
        "--extract-only",
        action="store_true",
        help="Only run extraction phase",
    )

    options = arg_parser.parse_args()
    run_pipeline(start_phase=options.phase, extract_only=options.extract_only)
|
||||
|
||||
|
||||
# Standard script entry point.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
"""
|
||||
split_dataset.py
|
||||
────────────────
|
||||
Camera-level train/test split for the extracted dataset.
|
||||
|
||||
Splits ENTIRE cameras into train or test sets to prevent data leakage
|
||||
from similar/consecutive surveillance frames.
|
||||
|
||||
- If >= 5 cameras: hold out 4 cameras for test
|
||||
- If < 5 cameras: hold out 1 camera for test
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import shutil
|
||||
import random
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import pipeline_config as cfg
|
||||
|
||||
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Make sure the log directory exists before attaching the file handler.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
# One timestamped log file per split run, mirrored to stdout.
log_file = os.path.join(cfg.LOG_DIR, f"split_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_camera_mapping() -> dict:
    """Load the camera→frames mapping created during extraction.

    Exits the process with an error if extraction has not produced
    dataset/camera_mapping.json yet.
    """
    mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")
    if not os.path.exists(mapping_path):
        logger.error(f"Camera mapping not found: {mapping_path}")
        logger.error("Run extract_dataset.py first!")
        sys.exit(1)

    with open(mapping_path, "r") as f:
        mapping = json.load(f)
    return mapping
|
||||
|
||||
|
||||
def split_cameras(camera_names: list[str]) -> tuple[list[str], list[str]]:
    """
    Deterministically split camera names into train and test sets.

    Hold-out policy: with >= 5 cameras, cfg.TEST_CAMERAS are held out
    (always leaving at least one for train); with fewer, exactly one.
    Returns (train_cameras, test_cameras). Exits if fewer than 2 cameras.
    """
    n_cameras = len(camera_names)
    if n_cameras < 2:
        logger.error(f"Need at least 2 cameras for splitting, found {n_cameras}")
        sys.exit(1)

    # Determine how many cameras to hold out
    if n_cameras >= 5:
        n_test = min(cfg.TEST_CAMERAS, n_cameras - 1)  # At least 1 for train
    else:
        n_test = 1

    # Use a private RNG instead of random.seed(): the original reseeded the
    # module-level generator, clobbering global random state for the rest
    # of the process. Random(seed).shuffle produces the exact same order as
    # seeding the global generator, so the split itself is unchanged.
    rng = random.Random(cfg.RANDOM_SEED)
    shuffled = camera_names.copy()
    rng.shuffle(shuffled)

    test_cameras = shuffled[:n_test]
    train_cameras = shuffled[n_test:]

    return train_cameras, test_cameras
|
||||
|
||||
|
||||
def move_frames(
    camera_mapping: dict,
    camera_names: list[str],
    split_name: str,
    src_images_dir: str,
    src_labels_dir: str,
):
    """Move frames belonging to *camera_names* into the split directory.

    Creates dataset/{images,labels}/<split_name>, then moves each mapped
    image together with its label file. Missing files are tolerated (e.g.
    already moved by an earlier partial run). Returns the number of frames
    whose image was actually moved.
    """
    dst_images = os.path.join(cfg.DATASET_DIR, "images", split_name)
    dst_labels = os.path.join(cfg.DATASET_DIR, "labels", split_name)
    os.makedirs(dst_images, exist_ok=True)
    os.makedirs(dst_labels, exist_ok=True)

    total_moved = 0

    for cam_name in camera_names:
        cam_data = camera_mapping[cam_name]
        frames = cam_data.get("frames", [])

        for img_filename in frames:
            # Image
            src_img = os.path.join(src_images_dir, img_filename)
            dst_img = os.path.join(dst_images, img_filename)
            if os.path.exists(src_img):
                shutil.move(src_img, dst_img)
                # Count only frames actually moved — the original counted
                # every mapped filename, inflating the log on re-runs.
                total_moved += 1

            # Label — derive the name with splitext instead of
            # str.replace(".jpg", ".txt"), which replaces the FIRST ".jpg"
            # occurrence and would corrupt a name containing ".jpg" mid-string.
            lbl_filename = os.path.splitext(img_filename)[0] + ".txt"
            src_lbl = os.path.join(src_labels_dir, lbl_filename)
            dst_lbl = os.path.join(dst_labels, lbl_filename)
            if os.path.exists(src_lbl):
                shutil.move(src_lbl, dst_lbl)

    logger.info(f" {split_name}: {total_moved} frames from {len(camera_names)} cameras")
    return total_moved
|
||||
|
||||
|
||||
def create_dataset_yaml(train_cameras: list[str], test_cameras: list[str]):
    """Write dataset.yaml for YOLO training and return its path.

    Records which cameras landed in each split as comments, and points
    train/val at the camera-level split directories (the held-out cameras
    double as the validation set).
    """
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")

    # YOLO expects forward slashes, even on Windows.
    dataset_path = cfg.DATASET_DIR.replace("\\", "/")

    content = f"""# Auto-generated dataset config for person detection
# Generated: {datetime.now().isoformat()}
#
# Train cameras: {', '.join(train_cameras)}
# Test cameras: {', '.join(test_cameras)}

path: {dataset_path}
train: images/train
val: images/test

nc: 1
names: ['person']
"""

    Path(yaml_path).write_text(content)

    logger.info(f"Created dataset.yaml at {yaml_path}")
    return yaml_path
|
||||
|
||||
|
||||
def split_dataset() -> dict:
|
||||
"""
|
||||
Main entry point: split extracted dataset by camera.
|
||||
Returns split info dict.
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("DATASET SPLIT (CAMERA-LEVEL)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Load camera mapping
|
||||
camera_mapping = load_camera_mapping()
|
||||
camera_names = sorted(camera_mapping.keys())
|
||||
logger.info(f"Found {len(camera_names)} cameras: {camera_names}")
|
||||
|
||||
# Count frames per camera
|
||||
for cam in camera_names:
|
||||
n_frames = len(camera_mapping[cam].get("frames", []))
|
||||
logger.info(f" {cam}: {n_frames} frames")
|
||||
|
||||
# Split cameras
|
||||
train_cameras, test_cameras = split_cameras(camera_names)
|
||||
logger.info(f"\nTrain cameras ({len(train_cameras)}): {train_cameras}")
|
||||
logger.info(f"Test cameras ({len(test_cameras)}): {test_cameras}")
|
||||
|
||||
# Source directories
|
||||
src_images = os.path.join(cfg.DATASET_DIR, "images", "all")
|
||||
src_labels = os.path.join(cfg.DATASET_DIR, "labels", "all")
|
||||
|
||||
if not os.path.exists(src_images):
|
||||
logger.error(f"Source images directory not found: {src_images}")
|
||||
sys.exit(1)
|
||||
|
||||
# Move frames to train/test directories
|
||||
logger.info("\nMoving frames to split directories...")
|
||||
n_train = move_frames(camera_mapping, train_cameras, "train", src_images, src_labels)
|
||||
n_test = move_frames(camera_mapping, test_cameras, "test", src_images, src_labels)
|
||||
|
||||
# Clean up empty 'all' directory
|
||||
try:
|
||||
remaining_imgs = os.listdir(src_images)
|
||||
if not remaining_imgs:
|
||||
os.rmdir(src_images)
|
||||
src_labels_check = os.path.join(cfg.DATASET_DIR, "labels", "all")
|
||||
remaining_lbls = os.listdir(src_labels_check)
|
||||
if not remaining_lbls:
|
||||
os.rmdir(src_labels_check)
|
||||
logger.info("Cleaned up empty 'all' directories")
|
||||
else:
|
||||
logger.warning(f"{len(remaining_imgs)} orphan images left in 'all' directory")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cleanup note: {e}")
|
||||
|
||||
# Create dataset.yaml
|
||||
yaml_path = create_dataset_yaml(train_cameras, test_cameras)
|
||||
|
||||
# Create classes.txt
|
||||
classes_path = os.path.join(cfg.DATASET_DIR, "classes.txt")
|
||||
with open(classes_path, "w") as f:
|
||||
f.write("person\n")
|
||||
|
||||
# Save split info
|
||||
split_info = {
|
||||
"train_cameras": train_cameras,
|
||||
"test_cameras": test_cameras,
|
||||
"train_frames": n_train,
|
||||
"test_frames": n_test,
|
||||
"random_seed": cfg.RANDOM_SEED,
|
||||
"yaml_path": yaml_path,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
split_info_path = os.path.join(cfg.DATASET_DIR, "split_info.json")
|
||||
with open(split_info_path, "w") as f:
|
||||
json.dump(split_info, f, indent=2)
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("SPLIT COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Train: {n_train} frames from {len(train_cameras)} cameras")
|
||||
logger.info(f"Test: {n_test} frames from {len(test_cameras)} cameras")
|
||||
logger.info(f"Ratio: {n_train/(n_train+n_test)*100:.1f}% / {n_test/(n_train+n_test)*100:.1f}%")
|
||||
|
||||
return split_info
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
split_dataset()
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
"""
|
||||
train_model.py
|
||||
──────────────
|
||||
Fine-tune yolo26n.pt on the extracted person detection dataset.
|
||||
|
||||
Optimized for NVIDIA RTX A5000 (16GB VRAM):
|
||||
- Mixed precision (AMP) enabled
|
||||
- Batch size 16, image size 640
|
||||
- Early stopping with patience 15
|
||||
- Full YOLO augmentation pipeline
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
import pipeline_config as cfg
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Logging
|
||||
# ──────────────────────────────────────────────
|
||||
os.makedirs(cfg.LOG_DIR, exist_ok=True)
|
||||
log_file = os.path.join(cfg.LOG_DIR, f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(log_file),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def verify_dataset() -> str:
|
||||
"""Verify dataset structure and return path to dataset.yaml."""
|
||||
yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
|
||||
if not os.path.exists(yaml_path):
|
||||
logger.error(f"dataset.yaml not found at {yaml_path}")
|
||||
logger.error("Run extract_dataset.py and split_dataset.py first!")
|
||||
sys.exit(1)
|
||||
|
||||
# Check directories exist
|
||||
train_images = os.path.join(cfg.DATASET_DIR, "images", "train")
|
||||
test_images = os.path.join(cfg.DATASET_DIR, "images", "test")
|
||||
|
||||
if not os.path.exists(train_images):
|
||||
logger.error(f"Train images directory not found: {train_images}")
|
||||
sys.exit(1)
|
||||
if not os.path.exists(test_images):
|
||||
logger.error(f"Test images directory not found: {test_images}")
|
||||
sys.exit(1)
|
||||
|
||||
n_train = len([f for f in os.listdir(train_images) if f.endswith(".jpg")])
|
||||
n_test = len([f for f in os.listdir(test_images) if f.endswith(".jpg")])
|
||||
|
||||
logger.info(f"Dataset verified:")
|
||||
logger.info(f" Train images: {n_train}")
|
||||
logger.info(f" Test images: {n_test}")
|
||||
|
||||
if n_train == 0:
|
||||
logger.error("No training images found!")
|
||||
sys.exit(1)
|
||||
if n_test == 0:
|
||||
logger.warning("No test images found — training will proceed without validation.")
|
||||
|
||||
return yaml_path
|
||||
|
||||
|
||||
def train_model() -> str:
|
||||
"""
|
||||
Main entry point: train yolo26n.pt on the person detection dataset.
|
||||
Returns path to best weights.
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("MODEL TRAINING STARTED")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Verify dataset
|
||||
yaml_path = verify_dataset()
|
||||
logger.info(f"Dataset config: {yaml_path}")
|
||||
|
||||
# Load base model
|
||||
logger.info(f"Loading base model: {cfg.TRAIN_MODEL}")
|
||||
model = YOLO(cfg.TRAIN_MODEL)
|
||||
logger.info("Base model loaded ✓")
|
||||
|
||||
# Training configuration
|
||||
train_args = {
|
||||
"data": yaml_path,
|
||||
"epochs": cfg.TRAIN_EPOCHS,
|
||||
"batch": cfg.TRAIN_BATCH,
|
||||
"imgsz": cfg.TRAIN_IMGSZ,
|
||||
"device": 0, # GPU
|
||||
"workers": cfg.TRAIN_WORKERS,
|
||||
"patience": cfg.EARLY_STOP_PATIENCE,
|
||||
"project": cfg.TRAIN_PROJECT,
|
||||
"name": cfg.TRAIN_NAME,
|
||||
"exist_ok": True, # Overwrite previous run
|
||||
"pretrained": True,
|
||||
"save": True,
|
||||
"save_period": 10, # Save checkpoint every 10 epochs
|
||||
"val": True,
|
||||
"plots": True, # Generate training plots
|
||||
"verbose": True,
|
||||
# Augmentation (YOLO defaults are good, but explicit for clarity)
|
||||
"hsv_h": 0.015,
|
||||
"hsv_s": 0.7,
|
||||
"hsv_v": 0.4,
|
||||
"degrees": 0.0,
|
||||
"translate": 0.1,
|
||||
"scale": 0.5,
|
||||
"shear": 0.0,
|
||||
"flipud": 0.0, # No vertical flip (people don't appear upside down)
|
||||
"fliplr": 0.5, # Horizontal flip is fine
|
||||
"mosaic": 1.0,
|
||||
"mixup": 0.1,
|
||||
}
|
||||
|
||||
logger.info("Training configuration:")
|
||||
for k, v in train_args.items():
|
||||
logger.info(f" {k}: {v}")
|
||||
|
||||
# Start training
|
||||
logger.info("\n" + "─" * 40)
|
||||
logger.info("TRAINING IN PROGRESS...")
|
||||
logger.info("─" * 40)
|
||||
|
||||
start_time = datetime.now()
|
||||
results = model.train(**train_args)
|
||||
end_time = datetime.now()
|
||||
duration = end_time - start_time
|
||||
|
||||
# Results
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TRAINING COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Duration: {duration}")
|
||||
|
||||
# Find best weights
|
||||
best_weights = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "weights", "best.pt")
|
||||
last_weights = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "weights", "last.pt")
|
||||
|
||||
if os.path.exists(best_weights):
|
||||
logger.info(f"Best weights: {best_weights}")
|
||||
# Copy best weights to project root for easy access
|
||||
import shutil
|
||||
output_model = os.path.join(cfg.BASE_DIR, "person_detector_best.pt")
|
||||
shutil.copy2(best_weights, output_model)
|
||||
logger.info(f"Copied best model to: {output_model}")
|
||||
else:
|
||||
logger.warning(f"Best weights not found at expected path: {best_weights}")
|
||||
best_weights = last_weights
|
||||
|
||||
# Save training summary
|
||||
summary = {
|
||||
"start_time": start_time.isoformat(),
|
||||
"end_time": end_time.isoformat(),
|
||||
"duration_seconds": duration.total_seconds(),
|
||||
"best_weights": best_weights,
|
||||
"train_args": {k: str(v) for k, v in train_args.items()},
|
||||
}
|
||||
|
||||
summary_path = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "training_summary.json")
|
||||
os.makedirs(os.path.dirname(summary_path), exist_ok=True)
|
||||
with open(summary_path, "w") as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
return best_weights
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_model()
|
||||
Loading…
Reference in New Issue