feat: implement adaptive video frame extraction pipeline with YOLO-based auto-labeling and checkpointing
parent
a1da64e017
commit
a7184a5773
|
|
@ -1,4 +1,6 @@
|
||||||
.env
|
.env
|
||||||
*.pt
|
*.pt
|
||||||
alerts/
|
alerts/
|
||||||
__pycache__
|
__pycache__
|
||||||
|
dataset/
|
||||||
|
video_data/
|
||||||
|
|
@ -0,0 +1,504 @@
|
||||||
|
"""
|
||||||
|
extract_dataset.py
|
||||||
|
──────────────────
|
||||||
|
Extracts person-detection frames from surveillance videos using adaptive
|
||||||
|
frame sampling and yolo26x.pt for auto-labeling.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Adaptive FPS: baseline 1 FPS, high 3 FPS (person), low 0.5 FPS (idle)
|
||||||
|
- GPU-accelerated YOLO inference in batches
|
||||||
|
- Per-video checkpointing for crash recovery
|
||||||
|
- 50 GB dataset size cap
|
||||||
|
- Organizes output in YOLO detection format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
import pipeline_config as cfg
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Ensure the log and checkpoint directories exist before the first write.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
os.makedirs(cfg.CHECKPOINT_DIR, exist_ok=True)

# One timestamped log file per extraction run; everything is also
# mirrored to stdout so overnight runs can be watched live.
log_file = os.path.join(cfg.LOG_DIR, f"extract_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# Helpers
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
def get_dataset_size_gb(dataset_dir: str) -> float:
    """Return total size of the dataset directory in GB."""
    # Walk the whole tree once and sum file sizes in a single pass.
    byte_count = sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, names in os.walk(dataset_dir)
        for name in names
    )
    return byte_count / (1024 ** 3)
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_path(video_path: str) -> str:
    """Return checkpoint file path for a given video."""
    # Checkpoints are keyed on the video's bare filename (no extension).
    stem = Path(video_path).stem
    return os.path.join(cfg.CHECKPOINT_DIR, stem + ".json")
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(video_path: str) -> dict | None:
    """Load checkpoint for a video if it exists, else None."""
    # EAFP: just try to open; a missing file is the expected "no checkpoint" case.
    try:
        with open(get_checkpoint_path(video_path), "r") as fh:
            return json.load(fh)
    except FileNotFoundError:
        return None
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(video_path: str, data: dict):
    """Persist checkpoint data for a video as pretty-printed JSON."""
    with open(get_checkpoint_path(video_path), "w") as fh:
        json.dump(data, fh, indent=2)
||||||
|
|
||||||
|
|
||||||
|
def mark_video_done(video_path: str, stats: dict):
    """Flag a video as fully processed and persist its final stats."""
    stats.update(done=True)
    save_checkpoint(video_path, stats)
||||||
|
|
||||||
|
|
||||||
|
def discover_videos() -> dict[str, list[str]]:
    """
    Discover all camera directories and their video files.
    Returns {camera_name: [video_paths]}.

    Exits the process when the video root or any videos are missing,
    since the pipeline cannot proceed without input.
    """
    video_dir = Path(cfg.VIDEO_DIR)
    if not video_dir.exists():
        logger.error(f"Video directory not found: {cfg.VIDEO_DIR}")
        sys.exit(1)

    cameras: dict[str, list[str]] = {}
    for cam_dir in sorted(video_dir.iterdir()):
        if not cam_dir.is_dir():
            continue
        # Collect into a set: on case-insensitive filesystems the lower/upper
        # extension globs can match the same file twice.
        found: set[str] = set()
        for ext in cfg.VIDEO_EXTENSIONS:
            for pattern in (f"*{ext}", f"*{ext.upper()}"):
                found.update(glob.glob(str(cam_dir / pattern)))
        if found:
            cameras[cam_dir.name] = sorted(found)
            logger.info(f"Camera '{cam_dir.name}': {len(found)} videos")

    if not cameras:
        logger.error("No video files found in any camera directory!")
        sys.exit(1)

    return cameras
||||||
|
|
||||||
|
|
||||||
|
def sanitize_camera_name(cam_name: str) -> str:
    """Create a filesystem-safe camera identifier.

    Maps spaces and hyphens to underscores, collapses ANY run of
    consecutive underscores to one, strips edge underscores, lowercases.

    FIX: the previous chained ``.replace("__", "_")`` only halved runs of
    underscores in a single pass, so e.g. "A - B" ("A_-_B" after the first
    replaces) still came out as "a__b". A regex collapse handles runs of
    any length.
    """
    import re  # local import keeps the module's import block untouched

    collapsed = re.sub(r"[ _-]+", "_", cam_name)
    return collapsed.strip("_").lower()
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# Adaptive Sampler State Machine
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
class AdaptiveSampler:
    """
    State machine for adaptive frame sampling.

    States:
    - NORMAL: sample at BASE_FPS (1 fps)
    - HIGH: sample at HIGH_FPS (3 fps) — person recently detected
    - LOW: sample at LOW_FPS (0.5 fps) — long idle period
    """

    def __init__(self, video_fps: float):
        self.video_fps = video_fps
        self.state = "NORMAL"
        self.current_sample_fps = cfg.BASE_FPS
        # Video timestamp (seconds) of the most recent person detection.
        self.last_person_time = 0.0
        # Seconds elapsed since the last person was seen.
        self.no_person_streak = 0.0

    def update(self, timestamp: float, person_detected: bool):
        """Advance the state machine given the detection result at `timestamp`."""
        if person_detected:
            # Any sighting snaps straight to HIGH and resets the idle clock.
            self.last_person_time = timestamp
            self.no_person_streak = 0.0
            self.state = "HIGH"
            self.current_sample_fps = cfg.HIGH_FPS
            return

        self.no_person_streak = timestamp - self.last_person_time

        if self.state == "HIGH" and self.no_person_streak > cfg.HIGH_FPS_DURATION:
            # Cool down to NORMAL once the post-detection window expires.
            self.state = "NORMAL"
            self.current_sample_fps = cfg.BASE_FPS
        elif self.state == "NORMAL" and self.no_person_streak > cfg.LOW_FPS_THRESHOLD:
            # Long idle period — drop to the slow sampling rate.
            self.state = "LOW"
            self.current_sample_fps = cfg.LOW_FPS
        # LOW stays LOW until a person is detected again.

    def get_frame_interval(self) -> int:
        """Return how many source frames to skip between samples (>= 1)."""
        return max(1, int(self.video_fps / self.current_sample_fps))
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# Core Extraction
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
def process_video(
    model: YOLO,
    video_path: str,
    camera_name: str,
    output_images_dir: str,
    output_labels_dir: str,
    global_stats: dict,
) -> dict:
    """
    Process a single video file: extract frames with adaptive sampling,
    detect persons with YOLO, save frames and labels.

    Resumable: progress is checkpointed every 30 s and on completion, so a
    crashed run picks up from the last saved frame index.

    Returns the per-video stats dict, or {"error": ..., "video": ...} when
    the video cannot be used.
    """
    cam_safe = sanitize_camera_name(camera_name)
    video_name = Path(video_path).stem

    # Skip videos that a previous run fully processed.
    checkpoint = load_checkpoint(video_path)
    if checkpoint and checkpoint.get("done"):
        logger.info(f" ⏭ Skipping (already done): {video_name}")
        return checkpoint

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f" ✗ Cannot open video: {video_path}")
        return {"error": "cannot_open", "video": video_name}

    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps <= 0:
        # FIX: previously only duration_sec was guarded against fps == 0;
        # the `timestamp = idx / video_fps` computation in the loop below
        # would raise ZeroDivisionError on a broken FPS header.
        cap.release()
        logger.error(f" ✗ Invalid FPS reported for video: {video_path}")
        return {"error": "invalid_fps", "video": video_name}
    duration_sec = total_frames / video_fps

    logger.info(f" ▶ Processing: {video_name}")
    logger.info(f" FPS: {video_fps:.1f}, Frames: {total_frames}, Duration: {duration_sec:.0f}s")

    # Resume position from a partial checkpoint, if any.
    start_frame = 0
    frame_counter = 0
    if checkpoint:
        start_frame = checkpoint.get("last_frame", 0)
        frame_counter = checkpoint.get("frame_counter", 0)
        logger.info(f" Resuming from frame {start_frame}")

    sampler = AdaptiveSampler(video_fps)
    stats = {
        "video": video_name,
        "camera": camera_name,
        "total_frames": total_frames,
        "frames_extracted": 0,
        "frames_with_person": 0,
        "frames_without_person": 0,
        "last_frame": start_frame,
        "frame_counter": frame_counter,
        "done": False,
    }

    # Seek to start position if resuming.
    if start_frame > 0:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

    current_frame_idx = start_frame
    batch_frames = []     # [(frame_idx, frame_array, timestamp), ...]
    batch_save_info = []  # [(output_img_path, output_lbl_path), ...]

    last_checkpoint_time = time.time()

    while True:
        # Enforce the global dataset size cap, checked every 500 extractions
        # (a full directory walk is too expensive to run per frame).
        if stats["frames_extracted"] % 500 == 0 and stats["frames_extracted"] > 0:
            current_size = get_dataset_size_gb(cfg.DATASET_DIR)
            if current_size >= cfg.MAX_DATASET_SIZE_GB:
                logger.warning(f" ⚠ Dataset size cap reached ({current_size:.1f} GB). Stopping extraction.")
                global_stats["size_cap_reached"] = True
                break

        # Advance by the sampler's current interval (adaptive FPS).
        interval = sampler.get_frame_interval()
        target_frame = current_frame_idx + interval
        if target_frame >= total_frames:
            break

        # Seek directly to the target frame instead of decoding every frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
        ret, frame = cap.read()
        if not ret:
            break

        current_frame_idx = target_frame
        timestamp = current_frame_idx / video_fps

        # Build a globally unique output filename per camera.
        frame_counter += 1
        frame_id = f"{cam_safe}_{frame_counter:06d}"
        img_path = os.path.join(output_images_dir, f"{frame_id}.jpg")
        lbl_path = os.path.join(output_labels_dir, f"{frame_id}.txt")

        batch_frames.append((current_frame_idx, frame, timestamp))
        batch_save_info.append((img_path, lbl_path))

        # Run inference once a full batch has accumulated.
        if len(batch_frames) >= cfg.BATCH_SIZE:
            # FIX: the return value was previously bound to an unused
            # `persons_in_batch` local; per-frame counts live in `stats`.
            _process_batch(model, batch_frames, batch_save_info, sampler, stats)
            batch_frames.clear()
            batch_save_info.clear()

        # Checkpoint every 30 seconds so a crash loses little work.
        if time.time() - last_checkpoint_time > 30:
            stats["last_frame"] = current_frame_idx
            stats["frame_counter"] = frame_counter
            save_checkpoint(video_path, stats)
            last_checkpoint_time = time.time()

        # Progress logging every 1000 extractions.
        if stats["frames_extracted"] % 1000 == 0 and stats["frames_extracted"] > 0:
            pct = (current_frame_idx / total_frames) * 100
            logger.info(
                f" Progress: {pct:.1f}% | Extracted: {stats['frames_extracted']} | "
                f"Persons: {stats['frames_with_person']} | Mode: {sampler.state} | "
                f"FPS: {sampler.current_sample_fps}"
            )

    # Flush any partial final batch.
    if batch_frames:
        _process_batch(model, batch_frames, batch_save_info, sampler, stats)

    cap.release()

    stats["last_frame"] = current_frame_idx
    stats["frame_counter"] = frame_counter
    mark_video_done(video_path, stats)

    logger.info(
        f" ✓ Done: {video_name} | Extracted: {stats['frames_extracted']} | "
        f"With person: {stats['frames_with_person']} | "
        f"Without: {stats['frames_without_person']}"
    )

    return stats
||||||
|
|
||||||
|
|
||||||
|
def _process_batch(
    model: YOLO,
    batch_frames: list,
    batch_save_info: list,
    sampler: AdaptiveSampler,
    stats: dict,
) -> int:
    """
    Run YOLO inference on a batch of frames, save images + labels.

    Frames without detections still get saved with an empty label file,
    i.e. they become negative samples. Updates `stats` in place and
    returns the number of frames with person detections.
    """
    frames = [f[1] for f in batch_frames]
    timestamps = [f[2] for f in batch_frames]

    # One batched GPU call — far faster than per-frame inference.
    results = model.predict(
        source=frames,
        conf=cfg.DETECTION_CONF,
        iou=cfg.DETECTION_IOU,
        classes=[cfg.PERSON_CLASS_ID],
        device=0,
        verbose=False,
        half=True,  # FP16 for speed
    )

    persons_count = 0

    for i, result in enumerate(results):
        img_path, lbl_path = batch_save_info[i]
        frame = frames[i]

        # Keep only person-class boxes (predict() is already class-filtered;
        # re-filtering here is cheap and defensive).
        boxes = result.boxes
        person_boxes = boxes[boxes.cls == cfg.PERSON_CLASS_ID]
        has_person = len(person_boxes) > 0

        # Feed the detection outcome back into the adaptive sampler.
        sampler.update(timestamps[i], has_person)

        # Save frame as JPEG at the configured quality.
        cv2.imwrite(img_path, frame, [cv2.IMWRITE_JPEG_QUALITY, cfg.JPEG_QUALITY])

        # Write YOLO-format labels: "class x_center y_center width height",
        # normalized to [0, 1]. An empty file marks a negative sample.
        h, w = frame.shape[:2]
        with open(lbl_path, "w") as f:
            for box in person_boxes:
                xyxy = box.xyxy[0].cpu().numpy()
                x_center = ((xyxy[0] + xyxy[2]) / 2) / w
                y_center = ((xyxy[1] + xyxy[3]) / 2) / h
                box_w = (xyxy[2] - xyxy[0]) / w
                box_h = (xyxy[3] - xyxy[1]) / h
                # FIX: removed the unused `conf = float(box.conf[0])` local —
                # confidence was computed but never written to the label.
                f.write(f"0 {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}\n")

        stats["frames_extracted"] += 1
        if has_person:
            stats["frames_with_person"] += 1
            persons_count += 1
        else:
            stats["frames_without_person"] += 1

    return persons_count
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# Main
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
def extract_all() -> dict:
    """
    Main entry point: discover videos, extract dataset.
    Returns {camera_name: [video_stats]}.

    Side effects: loads the detector model onto the GPU, writes frames and
    labels under DATASET_DIR/{images,labels}/all, maintains
    camera_mapping.json (persisted after every video for crash recovery),
    and writes extraction_stats.json at the end.
    """
    logger.info("=" * 60)
    logger.info("DATASET EXTRACTION STARTED")
    logger.info("=" * 60)

    # Load YOLO model on GPU — requires a CUDA device to be available.
    logger.info(f"Loading detector model: {cfg.DETECTOR_MODEL}")
    model = YOLO(cfg.DETECTOR_MODEL)
    model.to("cuda")
    logger.info("Model loaded on GPU ✓")

    # Discover cameras and videos (exits if none are found).
    cameras = discover_videos()
    total_videos = sum(len(v) for v in cameras.values())
    logger.info(f"Found {len(cameras)} cameras, {total_videos} videos total")

    # Create output directories (flat — split happens later)
    images_dir = os.path.join(cfg.DATASET_DIR, "images", "all")
    labels_dir = os.path.join(cfg.DATASET_DIR, "labels", "all")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    # Save camera mapping for split script
    camera_mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")

    # Load existing mapping if resuming a previous run.
    if os.path.exists(camera_mapping_path):
        with open(camera_mapping_path, "r") as f:
            camera_mapping = json.load(f)
    else:
        camera_mapping = {}

    # `size_cap_reached` is shared mutable state: process_video sets it when
    # the 50 GB cap is hit, and both loops below check it to stop early.
    global_stats = {"size_cap_reached": False}
    all_stats = {}
    video_num = 0

    for cam_name, video_list in cameras.items():
        cam_safe = sanitize_camera_name(cam_name)
        logger.info(f"\n{'─' * 40}")
        logger.info(f"Camera: {cam_name} ({len(video_list)} videos)")
        logger.info(f"{'─' * 40}")

        cam_stats = []
        for video_path in video_list:
            video_num += 1
            logger.info(f"\n[{video_num}/{total_videos}]")

            if global_stats["size_cap_reached"]:
                logger.warning("Size cap reached — skipping remaining videos.")
                break

            vstats = process_video(
                model, video_path, cam_name,
                images_dir, labels_dir, global_stats
            )
            cam_stats.append(vstats)

            # Track which frames belong to which camera
            if cam_name not in camera_mapping:
                camera_mapping[cam_name] = {"safe_name": cam_safe, "frames": []}

            # Collect frame IDs for this camera by prefix scan of the flat
            # images directory.
            # NOTE(review): if two distinct camera names sanitize to the same
            # safe prefix, their frame lists would merge here — confirm camera
            # directory names are unique after sanitization.
            frame_prefix = cam_safe + "_"
            existing_frames = camera_mapping[cam_name].get("frames", [])
            new_frames = [
                f for f in os.listdir(images_dir)
                if f.startswith(frame_prefix) and f.endswith(".jpg")
            ]
            camera_mapping[cam_name]["frames"] = sorted(set(existing_frames + new_frames))

            # Save mapping after each video so a crash never loses the
            # camera→frames association.
            with open(camera_mapping_path, "w") as f:
                json.dump(camera_mapping, f, indent=2)

        all_stats[cam_name] = cam_stats

        if global_stats["size_cap_reached"]:
            break

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("EXTRACTION COMPLETE")
    logger.info("=" * 60)

    total_extracted = 0
    total_persons = 0
    for cam, stats_list in all_stats.items():
        for s in stats_list:
            # .get() tolerates the {"error": ...} dicts process_video may return.
            total_extracted += s.get("frames_extracted", 0)
            total_persons += s.get("frames_with_person", 0)

    dataset_size = get_dataset_size_gb(cfg.DATASET_DIR)
    logger.info(f"Total frames extracted: {total_extracted}")
    logger.info(f"Frames with persons: {total_persons}")
    logger.info(f"Dataset size: {dataset_size:.2f} GB")
    logger.info(f"Cameras processed: {len(cameras)}")

    # Save final stats (default=str stringifies anything non-serializable).
    stats_path = os.path.join(cfg.DATASET_DIR, "extraction_stats.json")
    with open(stats_path, "w") as f:
        json.dump({
            "total_extracted": total_extracted,
            "total_with_persons": total_persons,
            "dataset_size_gb": round(dataset_size, 2),
            "cameras": {k: len(v) for k, v in cameras.items()},
            "all_stats": all_stats,
        }, f, indent=2, default=str)

    return all_stats
||||||
|
|
||||||
|
|
||||||
|
# Allow running extraction standalone (outside run_pipeline.py).
if __name__ == "__main__":
    extract_all()
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
"""
Centralized configuration for the person detection dataset pipeline.
All tunable parameters are defined here.
"""
import os

# ──────────────────────────────────────────────
# Paths
# ──────────────────────────────────────────────
# All paths are anchored at this file's directory, so the pipeline can be
# launched from any working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VIDEO_DIR = os.path.join(BASE_DIR, "video_data")
DATASET_DIR = os.path.join(BASE_DIR, "dataset")
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
LOG_DIR = os.path.join(BASE_DIR, "logs")

# Model paths
DETECTOR_MODEL = os.path.join(BASE_DIR, "yolo26x.pt")  # Large model for auto-labeling
TRAIN_MODEL = "yolo26n.pt"  # Nano model to train (auto-downloads)

# ──────────────────────────────────────────────
# Dataset Extraction
# ──────────────────────────────────────────────
MAX_DATASET_SIZE_GB = 50  # Stop extraction if dataset exceeds this
JPEG_QUALITY = 85  # JPEG save quality (1-100)
DETECTION_CONF = 0.35  # Min confidence for person detection
DETECTION_IOU = 0.45  # NMS IoU threshold
BATCH_SIZE = 16  # Frames per YOLO inference batch
PERSON_CLASS_ID = 0  # YOLO class ID for "person"

# ──────────────────────────────────────────────
# Adaptive Sampling
# ──────────────────────────────────────────────
BASE_FPS = 1.0  # Default: 1 frame per second
HIGH_FPS = 3.0  # When person detected: 3 frames per second
LOW_FPS = 0.5  # When idle (no person): 0.5 frames per second
HIGH_FPS_DURATION = 5  # Seconds to stay at high FPS after person detected
LOW_FPS_THRESHOLD = 10  # Seconds without person before dropping to low FPS

# ──────────────────────────────────────────────
# Train/Test Split (Camera-Level)
# ──────────────────────────────────────────────
TEST_CAMERAS = 4  # Number of cameras to hold out for testing
RANDOM_SEED = 42  # For reproducible camera selection

# ──────────────────────────────────────────────
# Training
# ──────────────────────────────────────────────
TRAIN_EPOCHS = 100
TRAIN_BATCH = 16  # Batch size for training (adjust for VRAM)
TRAIN_IMGSZ = 640  # Training image size
EARLY_STOP_PATIENCE = 15  # Stop if no improvement for N epochs
TRAIN_WORKERS = 8  # DataLoader workers
TRAIN_PROJECT = os.path.join(BASE_DIR, "runs", "detect")
TRAIN_NAME = "person_detection"

# ──────────────────────────────────────────────
# Video file extensions to process
# ──────────────────────────────────────────────
# Lowercase only — discovery also globs the uppercased variants.
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv"}
||||||
|
|
@ -0,0 +1,173 @@
|
||||||
|
"""
|
||||||
|
run_pipeline.py
|
||||||
|
───────────────
|
||||||
|
Master pipeline script that orchestrates the full workflow:
|
||||||
|
|
||||||
|
Phase 1: Extract frames from videos (adaptive sampling + YOLO detection)
|
||||||
|
Phase 2: Split dataset by camera (train/test)
|
||||||
|
Phase 3: Train yolo26n.pt on the dataset
|
||||||
|
|
||||||
|
Designed to run overnight — resumable from any phase.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python run_pipeline.py # Run full pipeline
|
||||||
|
python run_pipeline.py --phase 2 # Start from phase 2 (skip extraction)
|
||||||
|
python run_pipeline.py --phase 3 # Start from phase 3 (skip extract + split)
|
||||||
|
python run_pipeline.py --extract-only # Only extract (no split or train)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pipeline_config as cfg
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Create the log directory before attaching the file handler.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
# One timestamped log file per pipeline run, mirrored to stdout.
log_file = os.path.join(cfg.LOG_DIR, f"pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def phase1_extract():
    """Phase 1: Extract dataset from videos."""
    banner = "█" * 60
    logger.info("\n" + banner)
    logger.info("█ PHASE 1: DATASET EXTRACTION")
    logger.info(banner + "\n")

    # Imported lazily so later phases never pay the heavy import cost.
    from extract_dataset import extract_all
    return extract_all()
|
||||||
|
|
||||||
|
|
||||||
|
def phase2_split():
    """Phase 2: Split dataset by camera."""
    banner = "█" * 60
    logger.info("\n" + banner)
    logger.info("█ PHASE 2: CAMERA-LEVEL TRAIN/TEST SPLIT")
    logger.info(banner + "\n")

    # Lazy import: only needed when this phase actually runs.
    from split_dataset import split_dataset
    return split_dataset()
|
||||||
|
|
||||||
|
|
||||||
|
def phase3_train():
    """Phase 3: Train model. Returns the path to the best weights."""
    banner = "█" * 60
    logger.info("\n" + banner)
    logger.info("█ PHASE 3: MODEL TRAINING")
    logger.info(banner + "\n")

    # Lazy import: only needed when this phase actually runs.
    from train_model import train_model
    return train_model()
|
||||||
|
|
||||||
|
|
||||||
|
def run_pipeline(start_phase: int = 1, extract_only: bool = False):
    """Run the full pipeline from the specified starting phase.

    Args:
        start_phase: 1 = extract, 2 = split, 3 = train. Phases before
            `start_phase` are skipped; later ones run in order.
        extract_only: stop after Phase 1 regardless of later phases.

    Exits with status 1 on KeyboardInterrupt or any unhandled error;
    the individual phases checkpoint their progress, so re-running resumes.
    """
    pipeline_start = datetime.now()

    logger.info("╔" + "═" * 58 + "╗")
    logger.info("║ PERSON DETECTION PIPELINE ║")
    logger.info("║ " + f"Started: {pipeline_start.strftime('%Y-%m-%d %H:%M:%S')}".ljust(57) + "║")
    logger.info("╚" + "═" * 58 + "╝")
    logger.info("")
    # FIX: was an f-string with no placeholders (f"Configuration:").
    logger.info("Configuration:")
    logger.info(f" Video directory: {cfg.VIDEO_DIR}")
    logger.info(f" Dataset output: {cfg.DATASET_DIR}")
    logger.info(f" Detector model: {cfg.DETECTOR_MODEL}")
    logger.info(f" Training model: {cfg.TRAIN_MODEL}")
    logger.info(f" Max dataset size: {cfg.MAX_DATASET_SIZE_GB} GB")
    logger.info(f" Starting phase: {start_phase}")
    logger.info(f" Extract only: {extract_only}")
    logger.info("")

    try:
        # Phase 1: Extract
        if start_phase <= 1:
            p1_start = datetime.now()
            phase1_extract()
            p1_duration = datetime.now() - p1_start
            logger.info(f"\nPhase 1 completed in {p1_duration}")

            if extract_only:
                logger.info("Extract-only mode — stopping after Phase 1.")
                return

        # Phase 2: Split
        if start_phase <= 2:
            p2_start = datetime.now()
            phase2_split()
            p2_duration = datetime.now() - p2_start
            logger.info(f"\nPhase 2 completed in {p2_duration}")

        # Phase 3: Train
        if start_phase <= 3:
            p3_start = datetime.now()
            # FIX: the returned weights path was bound to an unused
            # `best_weights` local; drop the binding.
            phase3_train()
            p3_duration = datetime.now() - p3_start
            logger.info(f"\nPhase 3 completed in {p3_duration}")

        # Final summary
        total_duration = datetime.now() - pipeline_start
        logger.info("\n" + "╔" + "═" * 58 + "╗")
        logger.info("║ PIPELINE COMPLETED SUCCESSFULLY ║")
        logger.info("║ " + f"Duration: {total_duration}".ljust(57) + "║")
        logger.info("╚" + "═" * 58 + "╝")

    except KeyboardInterrupt:
        logger.warning("\n\nPipeline interrupted by user. Progress has been checkpointed.")
        logger.warning("Re-run to resume from where you left off.")
        sys.exit(1)

    except Exception as e:
        # Top-level boundary: log full traceback, then fail loudly.
        logger.error(f"\n\nPipeline failed with error: {e}")
        logger.error(traceback.format_exc())
        logger.error("Progress has been checkpointed. Fix the error and re-run.")
        sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse command-line arguments and launch the pipeline."""
    epilog_text = """
Examples:
python run_pipeline.py Run full pipeline
python run_pipeline.py --phase 2 Start from split phase
python run_pipeline.py --phase 3 Start from training phase
python run_pipeline.py --extract-only Only extract dataset
"""
    parser = argparse.ArgumentParser(
        description="Person detection dataset pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog_text,
    )
    parser.add_argument(
        "--phase",
        type=int,
        default=1,
        choices=[1, 2, 3],
        help="Starting phase (1=extract, 2=split, 3=train)",
    )
    parser.add_argument(
        "--extract-only",
        action="store_true",
        help="Only run extraction phase",
    )

    opts = parser.parse_args()
    run_pipeline(start_phase=opts.phase, extract_only=opts.extract_only)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point.
if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,232 @@
|
||||||
|
"""
|
||||||
|
split_dataset.py
|
||||||
|
────────────────
|
||||||
|
Camera-level train/test split for the extracted dataset.
|
||||||
|
|
||||||
|
Splits ENTIRE cameras into train or test sets to prevent data leakage
|
||||||
|
from similar/consecutive surveillance frames.
|
||||||
|
|
||||||
|
- If >= 5 cameras: hold out 4 cameras for test
|
||||||
|
- If < 5 cameras: hold out 1 camera for test
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pipeline_config as cfg
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Create the log directory before attaching the file handler.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
# One timestamped log file per split run, mirrored to stdout.
log_file = os.path.join(cfg.LOG_DIR, f"split_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_camera_mapping() -> dict:
    """Load the camera→frames mapping created during extraction.

    Exits the process if the mapping file is missing, since the split
    cannot be performed without it.
    """
    mapping_path = Path(cfg.DATASET_DIR) / "camera_mapping.json"
    if not mapping_path.exists():
        logger.error(f"Camera mapping not found: {mapping_path}")
        logger.error("Run extract_dataset.py first!")
        sys.exit(1)

    return json.loads(mapping_path.read_text())
|
||||||
|
|
||||||
|
|
||||||
|
def split_cameras(camera_names: list[str]) -> tuple[list[str], list[str]]:
    """
    Split camera names into train and test sets.
    Returns (train_cameras, test_cameras).

    FIX: uses a dedicated RNG instance seeded with cfg.RANDOM_SEED so the
    split stays reproducible WITHOUT reseeding the process-global `random`
    state (the previous `random.seed(...)` silently affected every other
    consumer of the `random` module in this process). The shuffle sequence
    produced is identical to the old behavior.
    """
    n_cameras = len(camera_names)
    if n_cameras < 2:
        logger.error(f"Need at least 2 cameras for splitting, found {n_cameras}")
        sys.exit(1)

    # Determine how many cameras to hold out.
    if n_cameras >= 5:
        n_test = min(cfg.TEST_CAMERAS, n_cameras - 1)  # At least 1 for train
    else:
        n_test = 1

    rng = random.Random(cfg.RANDOM_SEED)
    shuffled = camera_names.copy()
    rng.shuffle(shuffled)

    test_cameras = shuffled[:n_test]
    train_cameras = shuffled[n_test:]

    return train_cameras, test_cameras
|
||||||
|
|
||||||
|
|
||||||
|
def move_frames(
    camera_mapping: dict,
    camera_names: list[str],
    split_name: str,
    src_images_dir: str,
    src_labels_dir: str,
) -> int:
    """
    Move frames belonging to the given cameras into a split directory.

    Images go to <DATASET_DIR>/images/<split_name> and their YOLO label
    files to <DATASET_DIR>/labels/<split_name>. Missing source files are
    skipped silently (a frame may legitimately have no label file).

    Args:
        camera_mapping: camera name → {"frames": [image filenames, ...]}.
        camera_names: Cameras whose frames belong to this split.
        split_name: Split directory name, e.g. "train" or "test".
        src_images_dir: Directory currently holding all extracted images.
        src_labels_dir: Directory currently holding all label files.

    Returns:
        Number of image files actually moved.
    """
    dst_images = os.path.join(cfg.DATASET_DIR, "images", split_name)
    dst_labels = os.path.join(cfg.DATASET_DIR, "labels", split_name)
    os.makedirs(dst_images, exist_ok=True)
    os.makedirs(dst_labels, exist_ok=True)

    total_moved = 0

    for cam_name in camera_names:
        cam_data = camera_mapping[cam_name]
        frames = cam_data.get("frames", [])

        for img_filename in frames:
            # Image
            src_img = os.path.join(src_images_dir, img_filename)
            dst_img = os.path.join(dst_images, img_filename)
            if os.path.exists(src_img):
                shutil.move(src_img, dst_img)
                # Count only files that really moved, so the log and return
                # value stay honest even after a partial/interrupted run.
                total_moved += 1

            # Label: swap only the extension. (str.replace(".jpg", ".txt")
            # would also rewrite a ".jpg" appearing mid-name, producing a
            # wrong label filename.)
            lbl_filename = os.path.splitext(img_filename)[0] + ".txt"
            src_lbl = os.path.join(src_labels_dir, lbl_filename)
            dst_lbl = os.path.join(dst_labels, lbl_filename)
            if os.path.exists(src_lbl):
                shutil.move(src_lbl, dst_lbl)

    logger.info(f"  {split_name}: {total_moved} frames from {len(camera_names)} cameras")
    return total_moved
|
|
||||||
|
def create_dataset_yaml(train_cameras: list[str], test_cameras: list[str]):
    """Write the dataset.yaml that YOLO training consumes; return its path.

    The file records which cameras went into each split alongside the
    standard path/train/val/nc/names keys.
    """
    # YOLO prefers POSIX-style separators even on Windows
    root = cfg.DATASET_DIR.replace("\\", "/")

    yaml_text = f"""# Auto-generated dataset config for person detection
# Generated: {datetime.now().isoformat()}
#
# Train cameras: {', '.join(train_cameras)}
# Test cameras: {', '.join(test_cameras)}

path: {root}
train: images/train
val: images/test

nc: 1
names: ['person']
"""

    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
    with open(yaml_path, "w") as fh:
        fh.write(yaml_text)

    logger.info(f"Created dataset.yaml at {yaml_path}")
    return yaml_path
|
|
||||||
|
def split_dataset() -> dict:
    """
    Main entry point: split extracted dataset by camera.

    Moves frames from images/all + labels/all into train/test split
    directories, writes dataset.yaml, classes.txt and split_info.json,
    and cleans up the emptied 'all' directories.

    Returns split info dict (cameras, frame counts, seed, yaml path,
    timestamp) — the same dict persisted to split_info.json.
    """
    logger.info("=" * 60)
    logger.info("DATASET SPLIT (CAMERA-LEVEL)")
    logger.info("=" * 60)

    # Load camera mapping (exits if extraction hasn't been run)
    camera_mapping = load_camera_mapping()
    # Sorted so the camera order fed to the seeded shuffle is stable
    camera_names = sorted(camera_mapping.keys())
    logger.info(f"Found {len(camera_names)} cameras: {camera_names}")

    # Count frames per camera
    for cam in camera_names:
        n_frames = len(camera_mapping[cam].get("frames", []))
        logger.info(f"  {cam}: {n_frames} frames")

    # Split cameras (deterministic for a fixed cfg.RANDOM_SEED)
    train_cameras, test_cameras = split_cameras(camera_names)
    logger.info(f"\nTrain cameras ({len(train_cameras)}): {train_cameras}")
    logger.info(f"Test cameras ({len(test_cameras)}): {test_cameras}")

    # Source directories produced by extract_dataset.py
    src_images = os.path.join(cfg.DATASET_DIR, "images", "all")
    src_labels = os.path.join(cfg.DATASET_DIR, "labels", "all")

    if not os.path.exists(src_images):
        logger.error(f"Source images directory not found: {src_images}")
        sys.exit(1)

    # Move frames to train/test directories
    logger.info("\nMoving frames to split directories...")
    n_train = move_frames(camera_mapping, train_cameras, "train", src_images, src_labels)
    n_test = move_frames(camera_mapping, test_cameras, "test", src_images, src_labels)

    # Clean up empty 'all' directory — best-effort only; any filesystem
    # error here is logged and ignored because the split itself succeeded.
    try:
        remaining_imgs = os.listdir(src_images)
        if not remaining_imgs:
            os.rmdir(src_images)
            src_labels_check = os.path.join(cfg.DATASET_DIR, "labels", "all")
            remaining_lbls = os.listdir(src_labels_check)
            if not remaining_lbls:
                os.rmdir(src_labels_check)
            logger.info("Cleaned up empty 'all' directories")
        else:
            # Leftovers mean some files in 'all' were not listed in the
            # camera mapping; they are left in place for manual inspection.
            logger.warning(f"{len(remaining_imgs)} orphan images left in 'all' directory")
    except Exception as e:
        logger.warning(f"Cleanup note: {e}")

    # Create dataset.yaml for YOLO training
    yaml_path = create_dataset_yaml(train_cameras, test_cameras)

    # Create classes.txt (single-class person detector)
    classes_path = os.path.join(cfg.DATASET_DIR, "classes.txt")
    with open(classes_path, "w") as f:
        f.write("person\n")

    # Save split info so later stages can reproduce/inspect the split
    split_info = {
        "train_cameras": train_cameras,
        "test_cameras": test_cameras,
        "train_frames": n_train,
        "test_frames": n_test,
        "random_seed": cfg.RANDOM_SEED,
        "yaml_path": yaml_path,
        "timestamp": datetime.now().isoformat(),
    }

    split_info_path = os.path.join(cfg.DATASET_DIR, "split_info.json")
    with open(split_info_path, "w") as f:
        json.dump(split_info, f, indent=2)

    logger.info("\n" + "=" * 60)
    logger.info("SPLIT COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Train: {n_train} frames from {len(train_cameras)} cameras")
    logger.info(f"Test: {n_test} frames from {len(test_cameras)} cameras")
    # NOTE(review): divides by n_train + n_test — raises ZeroDivisionError
    # if both splits end up empty (e.g. every listed frame was missing).
    logger.info(f"Ratio: {n_train/(n_train+n_test)*100:.1f}% / {n_test/(n_train+n_test)*100:.1f}%")

    return split_info
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running the camera-level split standalone: python split_dataset.py
    split_dataset()
|
|
@ -0,0 +1,178 @@
|
||||||
|
"""
|
||||||
|
train_model.py
|
||||||
|
──────────────
|
||||||
|
Fine-tune yolo26n.pt on the extracted person detection dataset.
|
||||||
|
|
||||||
|
Optimized for NVIDIA RTX A5000 (16GB VRAM):
|
||||||
|
- Mixed precision (AMP) enabled
|
||||||
|
- Batch size 16, image size 640
|
||||||
|
- Early stopping with patience 15
|
||||||
|
- Full YOLO augmentation pipeline
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
import pipeline_config as cfg
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# Logging
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
os.makedirs(cfg.LOG_DIR, exist_ok=True)
|
||||||
|
log_file = os.path.join(cfg.LOG_DIR, f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler(log_file),
|
||||||
|
logging.StreamHandler(sys.stdout),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_dataset() -> str:
    """Sanity-check the split dataset layout; return the dataset.yaml path.

    Aborts the process when dataset.yaml, either split directory, or any
    training image is missing. An empty test split only logs a warning.
    """
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
    if not os.path.exists(yaml_path):
        logger.error(f"dataset.yaml not found at {yaml_path}")
        logger.error("Run extract_dataset.py and split_dataset.py first!")
        sys.exit(1)

    # Both split directories must exist before anything is counted.
    train_images = os.path.join(cfg.DATASET_DIR, "images", "train")
    test_images = os.path.join(cfg.DATASET_DIR, "images", "test")
    for label, directory in (("Train", train_images), ("Test", test_images)):
        if not os.path.exists(directory):
            logger.error(f"{label} images directory not found: {directory}")
            sys.exit(1)

    n_train = sum(1 for name in os.listdir(train_images) if name.endswith(".jpg"))
    n_test = sum(1 for name in os.listdir(test_images) if name.endswith(".jpg"))

    logger.info("Dataset verified:")
    logger.info(f"  Train images: {n_train}")
    logger.info(f"  Test images: {n_test}")

    if n_train == 0:
        logger.error("No training images found!")
        sys.exit(1)
    if n_test == 0:
        logger.warning("No test images found — training will proceed without validation.")

    return yaml_path
||||||
|
|
||||||
|
def train_model() -> str:
    """
    Main entry point: train yolo26n.pt on the person detection dataset.

    Verifies the dataset, runs model.train() on GPU 0, copies the best
    checkpoint to the project root, and writes a training_summary.json.

    Returns path to best weights (falls back to last.pt's path if best.pt
    was not produced).
    """
    logger.info("=" * 60)
    logger.info("MODEL TRAINING STARTED")
    logger.info("=" * 60)

    # Verify dataset (exits on a broken/missing layout)
    yaml_path = verify_dataset()
    logger.info(f"Dataset config: {yaml_path}")

    # Load base model specified by the pipeline config
    logger.info(f"Loading base model: {cfg.TRAIN_MODEL}")
    model = YOLO(cfg.TRAIN_MODEL)
    logger.info("Base model loaded ✓")

    # Training configuration — passed straight to Ultralytics model.train()
    train_args = {
        "data": yaml_path,
        "epochs": cfg.TRAIN_EPOCHS,
        "batch": cfg.TRAIN_BATCH,
        "imgsz": cfg.TRAIN_IMGSZ,
        "device": 0,  # GPU
        "workers": cfg.TRAIN_WORKERS,
        "patience": cfg.EARLY_STOP_PATIENCE,
        "project": cfg.TRAIN_PROJECT,
        "name": cfg.TRAIN_NAME,
        "exist_ok": True,  # Overwrite previous run
        "pretrained": True,
        "save": True,
        "save_period": 10,  # Save checkpoint every 10 epochs
        "val": True,
        "plots": True,  # Generate training plots
        "verbose": True,
        # Augmentation (YOLO defaults are good, but explicit for clarity)
        "hsv_h": 0.015,
        "hsv_s": 0.7,
        "hsv_v": 0.4,
        "degrees": 0.0,
        "translate": 0.1,
        "scale": 0.5,
        "shear": 0.0,
        "flipud": 0.0,  # No vertical flip (people don't appear upside down)
        "fliplr": 0.5,  # Horizontal flip is fine
        "mosaic": 1.0,
        "mixup": 0.1,
    }

    # Echo the full configuration into the log for reproducibility
    logger.info("Training configuration:")
    for k, v in train_args.items():
        logger.info(f"  {k}: {v}")

    # Start training
    logger.info("\n" + "─" * 40)
    logger.info("TRAINING IN PROGRESS...")
    logger.info("─" * 40)

    start_time = datetime.now()
    # NOTE(review): `results` is currently unused — metrics are left in the
    # Ultralytics run directory rather than surfaced here.
    results = model.train(**train_args)
    end_time = datetime.now()
    duration = end_time - start_time

    # Results
    logger.info("\n" + "=" * 60)
    logger.info("TRAINING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Duration: {duration}")

    # Find best weights at the standard Ultralytics output location
    best_weights = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "weights", "best.pt")
    last_weights = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "weights", "last.pt")

    if os.path.exists(best_weights):
        logger.info(f"Best weights: {best_weights}")
        # Copy best weights to project root for easy access
        import shutil  # local import: shutil is only needed on this path
        output_model = os.path.join(cfg.BASE_DIR, "person_detector_best.pt")
        shutil.copy2(best_weights, output_model)
        logger.info(f"Copied best model to: {output_model}")
    else:
        # Fall back to last.pt; NOTE(review): this path is not checked for
        # existence, so the returned path may not exist either.
        logger.warning(f"Best weights not found at expected path: {best_weights}")
        best_weights = last_weights

    # Save training summary alongside the run artifacts
    summary = {
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
        "duration_seconds": duration.total_seconds(),
        "best_weights": best_weights,
        # str() so every value (paths, numbers, bools) is JSON-safe
        "train_args": {k: str(v) for k, v in train_args.items()},
    }

    summary_path = os.path.join(cfg.TRAIN_PROJECT, cfg.TRAIN_NAME, "training_summary.json")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    return best_weights
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running training standalone: python train_model.py
    train_model()
Loading…
Reference in New Issue