feat: implement adaptive video frame extraction pipeline with YOLO-based auto-labeling and checkpointing

main
bahawal.baloch 2026-04-09 18:47:04 +05:00
parent a1da64e017
commit a7184a5773
6 changed files with 1149 additions and 1 deletions

4
.gitignore vendored
View File

@ -1,4 +1,6 @@
.env
*.pt
alerts/
__pycache__
__pycache__
dataset/
video_data/

504
extract_dataset.py Normal file
View File

@ -0,0 +1,504 @@
"""
extract_dataset.py
Extracts person-detection frames from surveillance videos using adaptive
frame sampling and yolo26x.pt for auto-labeling.
Features:
- Adaptive FPS: baseline 1 FPS, high 3 FPS (person), low 0.5 FPS (idle)
- GPU-accelerated YOLO inference in batches
- Per-video checkpointing for crash recovery
- 50 GB dataset size cap
- Organizes output in YOLO detection format
"""
import os
import sys
import cv2
import json
import time
import glob
import logging
import numpy as np
from pathlib import Path
from datetime import datetime
from ultralytics import YOLO
import pipeline_config as cfg
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Create output directories up-front so handlers and checkpoints can write.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
os.makedirs(cfg.CHECKPOINT_DIR, exist_ok=True)
# One timestamped log file per extraction run, mirrored to stdout.
log_file = os.path.join(cfg.LOG_DIR, f"extract_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────
def get_dataset_size_gb(dataset_dir: str) -> float:
    """Return the total on-disk size of *dataset_dir* (recursive) in GB."""
    byte_count = 0
    for root, _, names in os.walk(dataset_dir):
        byte_count += sum(os.path.getsize(os.path.join(root, n)) for n in names)
    return byte_count / (1024 ** 3)
def get_checkpoint_path(video_path: str) -> str:
    """Return the JSON checkpoint file path for *video_path* (keyed by stem)."""
    stem = Path(video_path).stem
    return os.path.join(cfg.CHECKPOINT_DIR, stem + ".json")
def load_checkpoint(video_path: str) -> dict | None:
    """Return the saved checkpoint dict for *video_path*, or None if absent."""
    cp_path = get_checkpoint_path(video_path)
    if not os.path.exists(cp_path):
        return None
    with open(cp_path, "r") as fh:
        return json.load(fh)
def save_checkpoint(video_path: str, data: dict):
    """Persist *data* as the pretty-printed JSON checkpoint for *video_path*."""
    with open(get_checkpoint_path(video_path), "w") as fh:
        json.dump(data, fh, indent=2)
def mark_video_done(video_path: str, stats: dict):
    """Flag *stats* as complete (mutates the dict in place) and checkpoint it."""
    stats.update(done=True)
    save_checkpoint(video_path, stats)
def discover_videos() -> dict[str, list[str]]:
    """
    Discover all camera directories and their video files.

    Each immediate subdirectory of cfg.VIDEO_DIR is treated as one camera;
    files matching cfg.VIDEO_EXTENSIONS (either case) are collected.
    Exits the process if the video root or all cameras are empty.

    Returns {camera_name: [sorted video_paths]}.
    """
    video_root = Path(cfg.VIDEO_DIR)
    if not video_root.exists():
        logger.error(f"Video directory not found: {cfg.VIDEO_DIR}")
        sys.exit(1)
    cameras: dict[str, list[str]] = {}
    for cam_dir in sorted(video_root.iterdir()):
        if not cam_dir.is_dir():
            continue
        found: set[str] = set()
        for ext in cfg.VIDEO_EXTENSIONS:
            for pattern in (f"*{ext}", f"*{ext.upper()}"):
                found.update(glob.glob(str(cam_dir / pattern)))
        if found:
            cameras[cam_dir.name] = sorted(found)
            logger.info(f"Camera '{cam_dir.name}': {len(found)} videos")
    if not cameras:
        logger.error("No video files found in any camera directory!")
        sys.exit(1)
    return cameras
def sanitize_camera_name(cam_name: str) -> str:
    """Create a filesystem-safe, lowercase camera identifier.

    Spaces and hyphens become underscores, runs of underscores collapse to a
    single one, and leading/trailing underscores are stripped.

    Fix: the previous single ``replace("__", "_")`` pass only collapsed one
    level, so inputs like ``"A - B"`` (-> ``"A___B"``) came out as ``"a__b"``;
    collapsing until no double underscore remains yields ``"a_b"``.
    """
    safe = cam_name.replace(" ", "_").replace("-", "_")
    while "__" in safe:
        safe = safe.replace("__", "_")
    return safe.strip("_").lower()
# ──────────────────────────────────────────────
# Adaptive Sampler State Machine
# ──────────────────────────────────────────────
class AdaptiveSampler:
    """
    State machine for adaptive frame sampling.

    States:
    - NORMAL: sample at BASE_FPS (1 fps)
    - HIGH: sample at HIGH_FPS (3 fps) person recently detected
    - LOW: sample at LOW_FPS (0.5 fps) long idle period

    Any person detection jumps straight to HIGH; HIGH decays to NORMAL after
    cfg.HIGH_FPS_DURATION idle seconds; NORMAL decays to LOW after
    cfg.LOW_FPS_THRESHOLD idle seconds; LOW persists until a new detection.
    """

    def __init__(self, video_fps: float):
        self.video_fps = video_fps
        self.state = "NORMAL"
        self.current_sample_fps = cfg.BASE_FPS
        self.last_person_time = 0.0   # video timestamp of last person detection
        self.no_person_streak = 0.0   # seconds since last person

    def _enter(self, state: str, sample_fps: float):
        """Switch to *state*, sampling at *sample_fps*."""
        self.state = state
        self.current_sample_fps = sample_fps

    def update(self, timestamp: float, person_detected: bool):
        """Update state based on the detection result at *timestamp* (seconds)."""
        if person_detected:
            self.last_person_time = timestamp
            self.no_person_streak = 0.0
            self._enter("HIGH", cfg.HIGH_FPS)
            return
        self.no_person_streak = timestamp - self.last_person_time
        if self.state == "HIGH" and self.no_person_streak > cfg.HIGH_FPS_DURATION:
            # Grace period after the last detection has expired.
            self._enter("NORMAL", cfg.BASE_FPS)
        elif self.state == "NORMAL" and self.no_person_streak > cfg.LOW_FPS_THRESHOLD:
            # Long idle stretch: throttle down until the next detection.
            self._enter("LOW", cfg.LOW_FPS)

    def get_frame_interval(self) -> int:
        """Return how many source-video frames to skip between samples (>= 1)."""
        return max(1, int(self.video_fps / self.current_sample_fps))
# ──────────────────────────────────────────────
# Core Extraction
# ──────────────────────────────────────────────
def process_video(
    model: YOLO,
    video_path: str,
    camera_name: str,
    output_images_dir: str,
    output_labels_dir: str,
    global_stats: dict,
) -> dict:
    """
    Process a single video file: extract frames with adaptive sampling,
    detect persons with YOLO, save frames and labels.

    Resumable: progress is checkpointed on a ~30 s timer and on completion,
    so a crashed run restarts from the last saved frame index. Frames are
    batched (cfg.BATCH_SIZE) before inference, so the adaptive sampler sees
    detection results with up to one batch of latency.

    Returns per-video stats dict (or an {"error": ...} dict if unreadable).
    """
    cam_safe = sanitize_camera_name(camera_name)
    video_name = Path(video_path).stem
    # Check if already done
    checkpoint = load_checkpoint(video_path)
    if checkpoint and checkpoint.get("done"):
        logger.info(f" ⏭ Skipping (already done): {video_name}")
        return checkpoint
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f" ✗ Cannot open video: {video_path}")
        return {"error": "cannot_open", "video": video_name}
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against codecs that report FPS as 0.
    duration_sec = total_frames / video_fps if video_fps > 0 else 0
    logger.info(f" ▶ Processing: {video_name}")
    logger.info(f" FPS: {video_fps:.1f}, Frames: {total_frames}, Duration: {duration_sec:.0f}s")
    # Resume from checkpoint
    start_frame = 0
    frame_counter = 0  # monotonic id embedded in output filenames
    if checkpoint:
        start_frame = checkpoint.get("last_frame", 0)
        frame_counter = checkpoint.get("frame_counter", 0)
        logger.info(f" Resuming from frame {start_frame}")
    sampler = AdaptiveSampler(video_fps)
    stats = {
        "video": video_name,
        "camera": camera_name,
        "total_frames": total_frames,
        "frames_extracted": 0,
        "frames_with_person": 0,
        "frames_without_person": 0,
        "last_frame": start_frame,
        "frame_counter": frame_counter,
        "done": False,
    }
    # Seek to start position if resuming
    if start_frame > 0:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    current_frame_idx = start_frame
    batch_frames = []     # [(frame_idx, frame_array, timestamp), ...]
    batch_save_info = []  # [(output_img_path, output_lbl_path), ...]
    last_checkpoint_time = time.time()
    while True:
        # Check dataset size cap periodically (full directory walk, so only
        # every 500 extracted frames).
        if stats["frames_extracted"] % 500 == 0 and stats["frames_extracted"] > 0:
            current_size = get_dataset_size_gb(cfg.DATASET_DIR)
            if current_size >= cfg.MAX_DATASET_SIZE_GB:
                logger.warning(f" ⚠ Dataset size cap reached ({current_size:.1f} GB). Stopping extraction.")
                global_stats["size_cap_reached"] = True
                break
        # Calculate next frame to sample from the sampler's current rate.
        interval = sampler.get_frame_interval()
        target_frame = current_frame_idx + interval
        if target_frame >= total_frames:
            break
        # Seek to target frame.
        # NOTE(review): a CAP_PROP_POS_FRAMES seek per sample can be slow or
        # imprecise on some codecs — confirm acceptable for the source videos.
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
        ret, frame = cap.read()
        if not ret:
            break
        current_frame_idx = target_frame
        timestamp = current_frame_idx / video_fps
        # Build unique output filename
        frame_counter += 1
        frame_id = f"{cam_safe}_{frame_counter:06d}"
        img_path = os.path.join(output_images_dir, f"{frame_id}.jpg")
        lbl_path = os.path.join(output_labels_dir, f"{frame_id}.txt")
        batch_frames.append((current_frame_idx, frame, timestamp))
        batch_save_info.append((img_path, lbl_path))
        # Process batch when full
        if len(batch_frames) >= cfg.BATCH_SIZE:
            # NOTE(review): persons_in_batch is never read afterwards.
            persons_in_batch = _process_batch(
                model, batch_frames, batch_save_info, sampler, stats
            )
            batch_frames.clear()
            batch_save_info.clear()
        # Checkpoint every 30 seconds (wall clock) for crash recovery.
        if time.time() - last_checkpoint_time > 30:
            stats["last_frame"] = current_frame_idx
            stats["frame_counter"] = frame_counter
            save_checkpoint(video_path, stats)
            last_checkpoint_time = time.time()
        # Progress logging every 1000 extractions
        if stats["frames_extracted"] % 1000 == 0 and stats["frames_extracted"] > 0:
            pct = (current_frame_idx / total_frames) * 100
            logger.info(
                f" Progress: {pct:.1f}% | Extracted: {stats['frames_extracted']} | "
                f"Persons: {stats['frames_with_person']} | Mode: {sampler.state} | "
                f"FPS: {sampler.current_sample_fps}"
            )
    # Process remaining (partial) batch
    if batch_frames:
        _process_batch(model, batch_frames, batch_save_info, sampler, stats)
    cap.release()
    stats["last_frame"] = current_frame_idx
    stats["frame_counter"] = frame_counter
    mark_video_done(video_path, stats)
    logger.info(
        f" ✓ Done: {video_name} | Extracted: {stats['frames_extracted']} | "
        f"With person: {stats['frames_with_person']} | "
        f"Without: {stats['frames_without_person']}"
    )
    return stats
def _process_batch(
    model: YOLO,
    batch_frames: list,
    batch_save_info: list,
    sampler: AdaptiveSampler,
    stats: dict,
) -> int:
    """
    Run YOLO inference on a batch of frames, save images + labels.

    Args:
        model: loaded YOLO detector (on GPU).
        batch_frames: [(frame_idx, frame_array, timestamp), ...].
        batch_save_info: [(img_path, lbl_path), ...] parallel to batch_frames.
        sampler: adaptive sampler; fed each per-frame detection result.
        stats: per-video stats dict, mutated in place.

    Returns the number of frames containing at least one person.

    Fix: removed the unused ``conf = float(box.conf[0])`` computation in the
    label-writing loop — confidence is not part of the YOLO label format and
    the value was never written or read.
    """
    frames = [item[1] for item in batch_frames]
    timestamps = [item[2] for item in batch_frames]
    # Batched GPU inference; classes=[...] restricts detections to persons.
    results = model.predict(
        source=frames,
        conf=cfg.DETECTION_CONF,
        iou=cfg.DETECTION_IOU,
        classes=[cfg.PERSON_CLASS_ID],
        device=0,
        verbose=False,
        half=True,  # FP16 for speed
    )
    persons_count = 0
    for i, result in enumerate(results):
        img_path, lbl_path = batch_save_info[i]
        timestamp = timestamps[i]
        frame = frames[i]
        boxes = result.boxes
        # Defensive re-filter in case the backend ignores the classes argument.
        person_boxes = boxes[boxes.cls == cfg.PERSON_CLASS_ID]
        has_person = len(person_boxes) > 0
        # Feed the result back so the sampler can adapt its rate.
        sampler.update(timestamp, has_person)
        # Save frame as JPEG
        cv2.imwrite(img_path, frame, [cv2.IMWRITE_JPEG_QUALITY, cfg.JPEG_QUALITY])
        # Save YOLO-format labels: "class x_center y_center width height",
        # all normalized to [0, 1]. An empty file marks a negative frame.
        h, w = frame.shape[:2]
        with open(lbl_path, "w") as f:
            for box in person_boxes:
                xyxy = box.xyxy[0].cpu().numpy()
                x_center = ((xyxy[0] + xyxy[2]) / 2) / w
                y_center = ((xyxy[1] + xyxy[3]) / 2) / h
                box_w = (xyxy[2] - xyxy[0]) / w
                box_h = (xyxy[3] - xyxy[1]) / h
                f.write(f"0 {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}\n")
        stats["frames_extracted"] += 1
        if has_person:
            stats["frames_with_person"] += 1
            persons_count += 1
        else:
            stats["frames_without_person"] += 1
    return persons_count
# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────
def extract_all() -> dict:
    """
    Main entry point: discover videos, extract the dataset.

    Walks every camera directory, processes each video with adaptive
    sampling + YOLO auto-labeling, and keeps camera_mapping.json current so
    the split step can assign whole cameras to train/test. Honors the
    global dataset size cap set inside process_video.

    Returns {camera_name: [video_stats]}.

    Fix: the per-camera banner was f"{'' * 40}" — an empty string repeated
    40 times, which logged nothing (the divider character was evidently lost);
    replaced with a visible ASCII divider.
    """
    logger.info("=" * 60)
    logger.info("DATASET EXTRACTION STARTED")
    logger.info("=" * 60)
    # Load the large auto-labeling model on GPU.
    logger.info(f"Loading detector model: {cfg.DETECTOR_MODEL}")
    model = YOLO(cfg.DETECTOR_MODEL)
    model.to("cuda")
    logger.info("Model loaded on GPU ✓")
    # Discover cameras and videos (exits if none found).
    cameras = discover_videos()
    total_videos = sum(len(v) for v in cameras.values())
    logger.info(f"Found {len(cameras)} cameras, {total_videos} videos total")
    # Create output directories (flat — the train/test split happens later).
    images_dir = os.path.join(cfg.DATASET_DIR, "images", "all")
    labels_dir = os.path.join(cfg.DATASET_DIR, "labels", "all")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)
    # Camera→frames mapping consumed by the split script.
    camera_mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")
    # Load the existing mapping if resuming a previous run.
    if os.path.exists(camera_mapping_path):
        with open(camera_mapping_path, "r") as f:
            camera_mapping = json.load(f)
    else:
        camera_mapping = {}
    global_stats = {"size_cap_reached": False}
    all_stats = {}
    video_num = 0
    for cam_name, video_list in cameras.items():
        cam_safe = sanitize_camera_name(cam_name)
        logger.info(f"\n{'-' * 40}")
        logger.info(f"Camera: {cam_name} ({len(video_list)} videos)")
        logger.info(f"{'-' * 40}")
        cam_stats = []
        for video_path in video_list:
            video_num += 1
            logger.info(f"\n[{video_num}/{total_videos}]")
            if global_stats["size_cap_reached"]:
                logger.warning("Size cap reached — skipping remaining videos.")
                break
            vstats = process_video(
                model, video_path, cam_name,
                images_dir, labels_dir, global_stats
            )
            cam_stats.append(vstats)
            # Track which frames belong to which camera.
            # NOTE(review): frames are matched by sanitized-name prefix, so two
            # cameras that sanitize to the same safe name would mix frames.
            if cam_name not in camera_mapping:
                camera_mapping[cam_name] = {"safe_name": cam_safe, "frames": []}
            frame_prefix = cam_safe + "_"
            existing_frames = camera_mapping[cam_name].get("frames", [])
            new_frames = [
                f for f in os.listdir(images_dir)
                if f.startswith(frame_prefix) and f.endswith(".jpg")
            ]
            camera_mapping[cam_name]["frames"] = sorted(set(existing_frames + new_frames))
            # Persist the mapping after every video so a crash loses nothing.
            with open(camera_mapping_path, "w") as f:
                json.dump(camera_mapping, f, indent=2)
        all_stats[cam_name] = cam_stats
        if global_stats["size_cap_reached"]:
            break
    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("EXTRACTION COMPLETE")
    logger.info("=" * 60)
    total_extracted = 0
    total_persons = 0
    for cam, stats_list in all_stats.items():
        for s in stats_list:
            total_extracted += s.get("frames_extracted", 0)
            total_persons += s.get("frames_with_person", 0)
    dataset_size = get_dataset_size_gb(cfg.DATASET_DIR)
    logger.info(f"Total frames extracted: {total_extracted}")
    logger.info(f"Frames with persons: {total_persons}")
    logger.info(f"Dataset size: {dataset_size:.2f} GB")
    logger.info(f"Cameras processed: {len(cameras)}")
    # Save final stats (default=str guards non-JSON-serializable values).
    stats_path = os.path.join(cfg.DATASET_DIR, "extraction_stats.json")
    with open(stats_path, "w") as f:
        json.dump({
            "total_extracted": total_extracted,
            "total_with_persons": total_persons,
            "dataset_size_gb": round(dataset_size, 2),
            "cameras": {k: len(v) for k, v in cameras.items()},
            "all_stats": all_stats,
        }, f, indent=2, default=str)
    return all_stats
# Allow running extraction directly as a standalone script.
if __name__ == "__main__":
    extract_all()

59
pipeline_config.py Normal file
View File

@ -0,0 +1,59 @@
"""
Centralized configuration for the person detection dataset pipeline.
All tunable parameters are defined here.
"""
import os
# ──────────────────────────────────────────────
# Paths
# ──────────────────────────────────────────────
BASE_DIR = os.path.dirname(os.path.abspath(__file__))   # project root = this file's directory
VIDEO_DIR = os.path.join(BASE_DIR, "video_data")        # input: one sub-directory per camera
DATASET_DIR = os.path.join(BASE_DIR, "dataset")         # output: extracted frames + labels
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")  # per-video resume state (JSON)
LOG_DIR = os.path.join(BASE_DIR, "logs")                # timestamped run logs
# Model paths
DETECTOR_MODEL = os.path.join(BASE_DIR, "yolo26x.pt")  # Large model for auto-labeling
TRAIN_MODEL = "yolo26n.pt"  # Nano model to train (auto-downloads)
# ──────────────────────────────────────────────
# Dataset Extraction
# ──────────────────────────────────────────────
MAX_DATASET_SIZE_GB = 50   # Stop extraction if dataset exceeds this
JPEG_QUALITY = 85          # JPEG save quality (1-100)
DETECTION_CONF = 0.35      # Min confidence for person detection
DETECTION_IOU = 0.45       # NMS IoU threshold
BATCH_SIZE = 16            # Frames per YOLO inference batch
PERSON_CLASS_ID = 0        # YOLO/COCO class ID for "person"
# ──────────────────────────────────────────────
# Adaptive Sampling (see AdaptiveSampler in extract_dataset.py)
# ──────────────────────────────────────────────
BASE_FPS = 1.0        # Default: 1 frame per second
HIGH_FPS = 3.0        # When person detected: 3 frames per second
LOW_FPS = 0.5         # When idle (no person): 0.5 frames per second
HIGH_FPS_DURATION = 5   # Seconds to stay at high FPS after person detected
LOW_FPS_THRESHOLD = 10  # Seconds without person before dropping to low FPS
# ──────────────────────────────────────────────
# Train/Test Split (Camera-Level)
# ──────────────────────────────────────────────
TEST_CAMERAS = 4   # Number of cameras to hold out for testing (when >= 5 exist)
RANDOM_SEED = 42   # For reproducible camera selection
# ──────────────────────────────────────────────
# Training
# ──────────────────────────────────────────────
TRAIN_EPOCHS = 100
TRAIN_BATCH = 16           # Batch size for training (adjust for VRAM)
TRAIN_IMGSZ = 640          # Training image size
EARLY_STOP_PATIENCE = 15   # Stop if no improvement for N epochs
TRAIN_WORKERS = 8          # DataLoader workers
TRAIN_PROJECT = os.path.join(BASE_DIR, "runs", "detect")
TRAIN_NAME = "person_detection"
# ──────────────────────────────────────────────
# Video file extensions to process (lowercase; uppercase matched at discovery)
# ──────────────────────────────────────────────
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv"}

173
run_pipeline.py Normal file
View File

@ -0,0 +1,173 @@
"""
run_pipeline.py
Master pipeline script that orchestrates the full workflow:
Phase 1: Extract frames from videos (adaptive sampling + YOLO detection)
Phase 2: Split dataset by camera (train/test)
Phase 3: Train yolo26n.pt on the dataset
Designed to run overnight — resumable from any phase.
Usage:
python run_pipeline.py # Run full pipeline
python run_pipeline.py --phase 2 # Start from phase 2 (skip extraction)
python run_pipeline.py --phase 3 # Start from phase 3 (skip extract + split)
python run_pipeline.py --extract-only # Only extract (no split or train)
"""
import os
import sys
import argparse
import logging
import traceback
from datetime import datetime
import pipeline_config as cfg
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# One timestamped log file per pipeline run, mirrored to stdout.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
log_file = os.path.join(cfg.LOG_DIR, f"pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
def phase1_extract():
    """Phase 1: Extract dataset from videos (delegates to extract_dataset)."""
    # Fix: the banner used '"" * 60' (empty string repeated — logs nothing;
    # the divider character was evidently lost). Restored a visible divider.
    logger.info("\n" + "█" * 60)
    logger.info("█ PHASE 1: DATASET EXTRACTION")
    logger.info("█" * 60 + "\n")
    # Imported lazily so later phases can run without extraction dependencies.
    from extract_dataset import extract_all
    stats = extract_all()
    return stats
def phase2_split():
    """Phase 2: Split dataset by camera (delegates to split_dataset)."""
    # Fix: restored the divider that was '"" * 60' (logged nothing).
    logger.info("\n" + "█" * 60)
    logger.info("█ PHASE 2: CAMERA-LEVEL TRAIN/TEST SPLIT")
    logger.info("█" * 60 + "\n")
    # Lazy import keeps phase modules decoupled until needed.
    from split_dataset import split_dataset
    split_info = split_dataset()
    return split_info
def phase3_train():
    """Phase 3: Train the model (delegates to train_model)."""
    # Fix: restored the divider that was '"" * 60' (logged nothing).
    logger.info("\n" + "█" * 60)
    logger.info("█ PHASE 3: MODEL TRAINING")
    logger.info("█" * 60 + "\n")
    # Lazy import avoids loading training deps unless this phase runs.
    from train_model import train_model
    best_weights = train_model()
    return best_weights
def run_pipeline(start_phase: int = 1, extract_only: bool = False):
    """Run the full pipeline from the specified starting phase.

    Args:
        start_phase: 1 = extract, 2 = split, 3 = train; earlier phases are
            skipped (useful when resuming a partially completed run).
        extract_only: stop after Phase 1.

    Fixes: the banner rows concatenated empty strings ('"" + "" * 58 + ""'),
    which logged blank lines — restored box-drawing characters consistent
    with the surviving '║ ... ║' rows; dropped the unused `best_weights`
    local from the Phase 3 call.
    """
    pipeline_start = datetime.now()
    logger.info("╔" + "═" * 58 + "╗")
    logger.info("║ PERSON DETECTION PIPELINE ║")
    logger.info("║ " + f"Started: {pipeline_start.strftime('%Y-%m-%d %H:%M:%S')}".ljust(57) + "║")
    logger.info("╚" + "═" * 58 + "╝")
    logger.info("")
    logger.info(f"Configuration:")
    logger.info(f" Video directory: {cfg.VIDEO_DIR}")
    logger.info(f" Dataset output: {cfg.DATASET_DIR}")
    logger.info(f" Detector model: {cfg.DETECTOR_MODEL}")
    logger.info(f" Training model: {cfg.TRAIN_MODEL}")
    logger.info(f" Max dataset size: {cfg.MAX_DATASET_SIZE_GB} GB")
    logger.info(f" Starting phase: {start_phase}")
    logger.info(f" Extract only: {extract_only}")
    logger.info("")
    try:
        # Phase 1: Extract
        if start_phase <= 1:
            p1_start = datetime.now()
            phase1_extract()
            p1_duration = datetime.now() - p1_start
            logger.info(f"\nPhase 1 completed in {p1_duration}")
            if extract_only:
                logger.info("Extract-only mode — stopping after Phase 1.")
                return
        # Phase 2: Split
        if start_phase <= 2:
            p2_start = datetime.now()
            phase2_split()
            p2_duration = datetime.now() - p2_start
            logger.info(f"\nPhase 2 completed in {p2_duration}")
        # Phase 3: Train
        if start_phase <= 3:
            p3_start = datetime.now()
            phase3_train()
            p3_duration = datetime.now() - p3_start
            logger.info(f"\nPhase 3 completed in {p3_duration}")
        # Final summary
        pipeline_end = datetime.now()
        total_duration = pipeline_end - pipeline_start
        logger.info("\n" + "╔" + "═" * 58 + "╗")
        logger.info("║ PIPELINE COMPLETED SUCCESSFULLY ║")
        logger.info("║ " + f"Duration: {total_duration}".ljust(57) + "║")
        logger.info("╚" + "═" * 58 + "╝")
    except KeyboardInterrupt:
        logger.warning("\n\nPipeline interrupted by user. Progress has been checkpointed.")
        logger.warning("Re-run to resume from where you left off.")
        sys.exit(1)
    except Exception as e:
        logger.error(f"\n\nPipeline failed with error: {e}")
        logger.error(traceback.format_exc())
        logger.error("Progress has been checkpointed. Fix the error and re-run.")
        sys.exit(1)
def main():
    """Parse CLI arguments and launch the pipeline."""
    arg_parser = argparse.ArgumentParser(
        description="Person detection dataset pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python run_pipeline.py Run full pipeline
python run_pipeline.py --phase 2 Start from split phase
python run_pipeline.py --phase 3 Start from training phase
python run_pipeline.py --extract-only Only extract dataset
""",
    )
    arg_parser.add_argument(
        "--phase",
        type=int,
        default=1,
        choices=[1, 2, 3],
        help="Starting phase (1=extract, 2=split, 3=train)",
    )
    arg_parser.add_argument(
        "--extract-only",
        action="store_true",
        help="Only run extraction phase",
    )
    opts = arg_parser.parse_args()
    run_pipeline(start_phase=opts.phase, extract_only=opts.extract_only)
# Standard script entry point.
if __name__ == "__main__":
    main()

232
split_dataset.py Normal file
View File

@ -0,0 +1,232 @@
"""
split_dataset.py
Camera-level train/test split for the extracted dataset.
Splits ENTIRE cameras into train or test sets to prevent data leakage
from similar/consecutive surveillance frames.
- If >= 5 cameras: hold out 4 cameras for test
- If < 5 cameras: hold out 1 camera for test
"""
import os
import sys
import json
import shutil
import random
import logging
from pathlib import Path
from datetime import datetime
import pipeline_config as cfg
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# One timestamped log file per split run, mirrored to stdout.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
log_file = os.path.join(cfg.LOG_DIR, f"split_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
def load_camera_mapping() -> dict:
    """Load the camera→frames mapping created during extraction.

    Exits the process with an error if extraction has not been run yet.
    """
    mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")
    if os.path.exists(mapping_path):
        with open(mapping_path, "r") as fh:
            return json.load(fh)
    logger.error(f"Camera mapping not found: {mapping_path}")
    logger.error("Run extract_dataset.py first!")
    sys.exit(1)
def split_cameras(camera_names: list[str]) -> tuple[list[str], list[str]]:
    """
    Split camera names into train and test sets.

    Deterministic for a fixed cfg.RANDOM_SEED. With >= 5 cameras, up to
    cfg.TEST_CAMERAS are held out (always leaving at least one for training);
    with fewer, exactly one camera is held out.

    Returns (train_cameras, test_cameras).
    """
    random.seed(cfg.RANDOM_SEED)
    n_cameras = len(camera_names)
    if n_cameras < 2:
        logger.error(f"Need at least 2 cameras for splitting, found {n_cameras}")
        sys.exit(1)
    # Determine how many cameras to hold out for testing.
    n_test = min(cfg.TEST_CAMERAS, n_cameras - 1) if n_cameras >= 5 else 1
    shuffled = list(camera_names)
    random.shuffle(shuffled)
    return shuffled[n_test:], shuffled[:n_test]
def move_frames(
    camera_mapping: dict,
    camera_names: list[str],
    split_name: str,
    src_images_dir: str,
    src_labels_dir: str,
):
    """Move frames (image + matching label) for the given cameras into the
    split directory named *split_name*; returns the number of frames handled."""
    dst_images = os.path.join(cfg.DATASET_DIR, "images", split_name)
    dst_labels = os.path.join(cfg.DATASET_DIR, "labels", split_name)
    for directory in (dst_images, dst_labels):
        os.makedirs(directory, exist_ok=True)
    total_moved = 0
    for cam_name in camera_names:
        for img_filename in camera_mapping[cam_name].get("frames", []):
            # Move the image if it is still in the source directory.
            src_img = os.path.join(src_images_dir, img_filename)
            if os.path.exists(src_img):
                shutil.move(src_img, os.path.join(dst_images, img_filename))
            # Move the matching YOLO label file alongside it.
            lbl_filename = img_filename.replace(".jpg", ".txt")
            src_lbl = os.path.join(src_labels_dir, lbl_filename)
            if os.path.exists(src_lbl):
                shutil.move(src_lbl, os.path.join(dst_labels, lbl_filename))
            total_moved += 1
    logger.info(f" {split_name}: {total_moved} frames from {len(camera_names)} cameras")
    return total_moved
def create_dataset_yaml(train_cameras: list[str], test_cameras: list[str]):
    """Write dataset.yaml for YOLO training and return its path.

    The held-out camera split doubles as the validation set.
    """
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
    # YOLO expects forward slashes, even on Windows.
    dataset_path = cfg.DATASET_DIR.replace("\\", "/")
    content = f"""# Auto-generated dataset config for person detection
# Generated: {datetime.now().isoformat()}
#
# Train cameras: {', '.join(train_cameras)}
# Test cameras: {', '.join(test_cameras)}
path: {dataset_path}
train: images/train
val: images/test
nc: 1
names: ['person']
"""
    Path(yaml_path).write_text(content)
    logger.info(f"Created dataset.yaml at {yaml_path}")
    return yaml_path
def split_dataset() -> dict:
    """
    Main entry point: split extracted dataset by camera.

    Moves frames from images/all + labels/all into train/test directories
    according to a camera-level split (whole cameras per split, to avoid
    near-duplicate surveillance frames leaking across splits), then writes
    dataset.yaml, classes.txt and split_info.json.

    Returns split info dict.
    """
    logger.info("=" * 60)
    logger.info("DATASET SPLIT (CAMERA-LEVEL)")
    logger.info("=" * 60)
    # Load camera mapping produced by extract_dataset.py (exits if missing).
    camera_mapping = load_camera_mapping()
    camera_names = sorted(camera_mapping.keys())
    logger.info(f"Found {len(camera_names)} cameras: {camera_names}")
    # Count frames per camera
    for cam in camera_names:
        n_frames = len(camera_mapping[cam].get("frames", []))
        logger.info(f" {cam}: {n_frames} frames")
    # Split cameras (deterministic under cfg.RANDOM_SEED)
    train_cameras, test_cameras = split_cameras(camera_names)
    logger.info(f"\nTrain cameras ({len(train_cameras)}): {train_cameras}")
    logger.info(f"Test cameras ({len(test_cameras)}): {test_cameras}")
    # Source directories
    src_images = os.path.join(cfg.DATASET_DIR, "images", "all")
    src_labels = os.path.join(cfg.DATASET_DIR, "labels", "all")
    if not os.path.exists(src_images):
        logger.error(f"Source images directory not found: {src_images}")
        sys.exit(1)
    # Move frames to train/test directories
    logger.info("\nMoving frames to split directories...")
    n_train = move_frames(camera_mapping, train_cameras, "train", src_images, src_labels)
    n_test = move_frames(camera_mapping, test_cameras, "test", src_images, src_labels)
    # Clean up the 'all' staging directories once emptied; best-effort only.
    try:
        remaining_imgs = os.listdir(src_images)
        if not remaining_imgs:
            os.rmdir(src_images)
            src_labels_check = os.path.join(cfg.DATASET_DIR, "labels", "all")
            remaining_lbls = os.listdir(src_labels_check)
            if not remaining_lbls:
                os.rmdir(src_labels_check)
            logger.info("Cleaned up empty 'all' directories")
        else:
            logger.warning(f"{len(remaining_imgs)} orphan images left in 'all' directory")
    except Exception as e:
        # Deliberate best-effort: cleanup failure must not fail the split.
        logger.warning(f"Cleanup note: {e}")
    # Create dataset.yaml for YOLO training
    yaml_path = create_dataset_yaml(train_cameras, test_cameras)
    # Create classes.txt (single class: person)
    classes_path = os.path.join(cfg.DATASET_DIR, "classes.txt")
    with open(classes_path, "w") as f:
        f.write("person\n")
    # Save split info
    split_info = {
        "train_cameras": train_cameras,
        "test_cameras": test_cameras,
        "train_frames": n_train,
        "test_frames": n_test,
        "random_seed": cfg.RANDOM_SEED,
        "yaml_path": yaml_path,
        "timestamp": datetime.now().isoformat(),
    }
    split_info_path = os.path.join(cfg.DATASET_DIR, "split_info.json")
    with open(split_info_path, "w") as f:
        json.dump(split_info, f, indent=2)
    logger.info("\n" + "=" * 60)
    logger.info("SPLIT COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Train: {n_train} frames from {len(train_cameras)} cameras")
    logger.info(f"Test: {n_test} frames from {len(test_cameras)} cameras")
    # NOTE(review): this ratio raises ZeroDivisionError if both splits are
    # empty (possible when camera_mapping lists no frames) — confirm intended.
    logger.info(f"Ratio: {n_train/(n_train+n_test)*100:.1f}% / {n_test/(n_train+n_test)*100:.1f}%")
    return split_info
# Allow running the split directly as a standalone script.
if __name__ == "__main__":
    split_dataset()

178
train_model.py Normal file
View File

@ -0,0 +1,178 @@
"""
train_model.py
Fine-tune yolo26n.pt on the extracted person detection dataset.
Optimized for NVIDIA RTX A5000 (16GB VRAM):
- Mixed precision (AMP) enabled
- Batch size 16, image size 640
- Early stopping with patience 15
- Full YOLO augmentation pipeline
"""
import os
import sys
import json
import logging
from pathlib import Path
from datetime import datetime
from ultralytics import YOLO
import pipeline_config as cfg
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# One timestamped log file per training run, mirrored to stdout.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
log_file = os.path.join(cfg.LOG_DIR, f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
def verify_dataset() -> str:
    """Verify the split dataset layout and return the path to dataset.yaml.

    Exits the process if dataset.yaml, the train directory, or any training
    images are missing; an empty test set only produces a warning.
    """
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
    if not os.path.exists(yaml_path):
        logger.error(f"dataset.yaml not found at {yaml_path}")
        logger.error("Run extract_dataset.py and split_dataset.py first!")
        sys.exit(1)
    train_images = os.path.join(cfg.DATASET_DIR, "images", "train")
    test_images = os.path.join(cfg.DATASET_DIR, "images", "test")
    for split_dir, label in ((train_images, "Train"), (test_images, "Test")):
        if not os.path.exists(split_dir):
            logger.error(f"{label} images directory not found: {split_dir}")
            sys.exit(1)

    def _count_jpgs(directory: str) -> int:
        # Only .jpg frames count — that is what extraction produces.
        return sum(1 for name in os.listdir(directory) if name.endswith(".jpg"))

    n_train = _count_jpgs(train_images)
    n_test = _count_jpgs(test_images)
    logger.info(f"Dataset verified:")
    logger.info(f" Train images: {n_train}")
    logger.info(f" Test images: {n_test}")
    if n_train == 0:
        logger.error("No training images found!")
        sys.exit(1)
    if n_test == 0:
        logger.warning("No test images found — training will proceed without validation.")
    return yaml_path
def train_model() -> str:
    """
    Main entry point: train yolo26n.pt on the person detection dataset.

    Verifies the dataset, runs Ultralytics training with the configuration
    from pipeline_config, copies the best weights to the project root, and
    writes a JSON training summary.

    Returns path to the best weights (falls back to last.pt if best.pt is
    missing).

    Fixes: the "TRAINING IN PROGRESS" banner used '"" * 40' (empty string
    repeated — logged nothing); also dropped the unused `results` binding
    from the train() call.
    """
    logger.info("=" * 60)
    logger.info("MODEL TRAINING STARTED")
    logger.info("=" * 60)
    # Verify dataset layout (exits on failure).
    yaml_path = verify_dataset()
    logger.info(f"Dataset config: {yaml_path}")
    # Load base model (auto-downloads if not cached).
    logger.info(f"Loading base model: {cfg.TRAIN_MODEL}")
    model = YOLO(cfg.TRAIN_MODEL)
    logger.info("Base model loaded ✓")
    # Training configuration
    train_args = {
        "data": yaml_path,
        "epochs": cfg.TRAIN_EPOCHS,
        "batch": cfg.TRAIN_BATCH,
        "imgsz": cfg.TRAIN_IMGSZ,
        "device": 0,  # GPU
        "workers": cfg.TRAIN_WORKERS,
        "patience": cfg.EARLY_STOP_PATIENCE,
        "project": cfg.TRAIN_PROJECT,
        "name": cfg.TRAIN_NAME,
        "exist_ok": True,  # Overwrite previous run
        "pretrained": True,
        "save": True,
        "save_period": 10,  # Save checkpoint every 10 epochs
        "val": True,
        "plots": True,  # Generate training plots
        "verbose": True,
        # Augmentation (YOLO defaults are good, but explicit for clarity)
        "hsv_h": 0.015,
        "hsv_s": 0.7,
        "hsv_v": 0.4,
        "degrees": 0.0,
        "translate": 0.1,
        "scale": 0.5,
        "shear": 0.0,
        "flipud": 0.0,  # No vertical flip (people don't appear upside down)
        "fliplr": 0.5,  # Horizontal flip is fine
        "mosaic": 1.0,
        "mixup": 0.1,
    }
    logger.info("Training configuration:")
    for k, v in train_args.items():
        logger.info(f" {k}: {v}")
    # Start training
    logger.info("\n" + "=" * 40)
    logger.info("TRAINING IN PROGRESS...")
    logger.info("=" * 40)
    start_time = datetime.now()
    model.train(**train_args)
    end_time = datetime.now()
    duration = end_time - start_time
    # Results
    logger.info("\n" + "=" * 60)
    logger.info("TRAINING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Duration: {duration}")
    # Locate the weights Ultralytics wrote for this run.
    best_weights = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "weights", "best.pt")
    last_weights = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "weights", "last.pt")
    if os.path.exists(best_weights):
        logger.info(f"Best weights: {best_weights}")
        # Copy best weights to project root for easy access.
        import shutil
        output_model = os.path.join(cfg.BASE_DIR, "person_detector_best.pt")
        shutil.copy2(best_weights, output_model)
        logger.info(f"Copied best model to: {output_model}")
    else:
        logger.warning(f"Best weights not found at expected path: {best_weights}")
        best_weights = last_weights
    # Save training summary (args stringified for JSON safety).
    summary = {
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
        "duration_seconds": duration.total_seconds(),
        "best_weights": best_weights,
        "train_args": {k: str(v) for k, v in train_args.items()},
    }
    summary_path = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "training_summary.json")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    return best_weights
# Allow running training directly as a standalone script.
if __name__ == "__main__":
    train_model()