"""
|
|
extract_dataset.py
|
|
──────────────────
|
|
Extracts person-detection frames from surveillance videos using adaptive
|
|
frame sampling and yolo26x.pt for auto-labeling.
|
|
|
|
Features:
|
|
- Adaptive FPS: baseline 1 FPS, high 3 FPS (person), low 0.5 FPS (idle)
|
|
- GPU-accelerated YOLO inference in batches
|
|
- Per-video checkpointing for crash recovery
|
|
- 50 GB dataset size cap
|
|
- Organizes output in YOLO detection format
|
|
"""
|
|
|
|
import os
import sys
import cv2
import json
import time
import glob
import logging
from pathlib import Path
from datetime import datetime

from ultralytics import YOLO

import pipeline_config as cfg

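# A sketch of the pipeline_config values this script expects. Only the
# sampling rates, the 50 GB cap, and the detector name are stated in the
# docstring above; the remaining values are hypothetical placeholders and
# the real module is the source of truth:
#
#   DETECTOR_MODEL = "yolo26x.pt"
#   VIDEO_DIR, DATASET_DIR, LOG_DIR, CHECKPOINT_DIR    # pipeline paths
#   VIDEO_EXTENSIONS = [".mp4", ".avi", ".mkv"]        # hypothetical list
#   BASE_FPS, HIGH_FPS, LOW_FPS = 1.0, 3.0, 0.5        # sampling rates
#   HIGH_FPS_DURATION, LOW_FPS_THRESHOLD = 30, 300     # seconds, hypothetical
#   BATCH_SIZE = 16                                    # hypothetical
#   DETECTION_CONF, DETECTION_IOU = 0.25, 0.45         # hypothetical
#   PERSON_CLASS_ID = 0                                # COCO "person"
#   JPEG_QUALITY = 90                                  # hypothetical
#   MAX_DATASET_SIZE_GB = 50
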
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
os.makedirs(cfg.LOG_DIR, exist_ok=True)
os.makedirs(cfg.CHECKPOINT_DIR, exist_ok=True)

log_file = os.path.join(cfg.LOG_DIR, f"extract_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)

# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────
def get_dataset_size_gb(dataset_dir: str) -> float:
    """Return total size of the dataset directory in GB."""
    total = 0
    for dirpath, _, filenames in os.walk(dataset_dir):
        for f in filenames:
            total += os.path.getsize(os.path.join(dirpath, f))
    return total / (1024 ** 3)


def get_checkpoint_path(video_path: str) -> str:
    """Return checkpoint file path for a given video."""
    video_name = Path(video_path).stem
    return os.path.join(cfg.CHECKPOINT_DIR, f"{video_name}.json")


def load_checkpoint(video_path: str) -> dict | None:
    """Load checkpoint for a video if it exists."""
    cp_path = get_checkpoint_path(video_path)
    if os.path.exists(cp_path):
        with open(cp_path, "r") as f:
            return json.load(f)
    return None


def save_checkpoint(video_path: str, data: dict):
    """Save checkpoint for a video."""
    cp_path = get_checkpoint_path(video_path)
    with open(cp_path, "w") as f:
        json.dump(data, f, indent=2)


def mark_video_done(video_path: str, stats: dict):
    """Mark a video as fully processed."""
    stats["done"] = True
    save_checkpoint(video_path, stats)

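# Checkpoints are one small JSON file per video in cfg.CHECKPOINT_DIR. The
# keys mirror the per-video stats dict built in process_video below
# (illustrative contents):
#
#   {
#     "video": "2024-01-01_00",
#     "camera": "Front Door",
#     "total_frames": 108000,
#     "frames_extracted": 1500,
#     "frames_with_person": 320,
#     "frames_without_person": 1180,
#     "last_frame": 43200,
#     "frame_counter": 1500,
#     "done": false
#   }
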
def discover_videos() -> dict[str, list[str]]:
    """
    Discover all camera directories and their video files.
    Returns {camera_name: [video_paths]}.
    """
    cameras = {}
    video_dir = Path(cfg.VIDEO_DIR)

    if not video_dir.exists():
        logger.error(f"Video directory not found: {cfg.VIDEO_DIR}")
        sys.exit(1)

    for cam_dir in sorted(video_dir.iterdir()):
        if cam_dir.is_dir():
            videos = []
            for ext in cfg.VIDEO_EXTENSIONS:
                videos.extend(glob.glob(str(cam_dir / f"*{ext}")))
                videos.extend(glob.glob(str(cam_dir / f"*{ext.upper()}")))
            # De-duplicate (case-insensitive filesystems can match both globs)
            videos = sorted(set(videos))
            if videos:
                cameras[cam_dir.name] = videos
                logger.info(f"Camera '{cam_dir.name}': {len(videos)} videos")

    if not cameras:
        logger.error("No video files found in any camera directory!")
        sys.exit(1)

    return cameras

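# Expected input layout (illustrative; the root comes from pipeline_config):
#
#   cfg.VIDEO_DIR/
#     Front Door/              # one subdirectory per camera
#       2024-01-01_00.mp4
#       2024-01-01_01.mp4
#     Driveway/
#       ...
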
def sanitize_camera_name(cam_name: str) -> str:
    """Create a filesystem-safe camera identifier, e.g. "Front Door" -> "front_door"."""
    return cam_name.replace(" ", "_").replace("-", "_").replace("__", "_").strip("_").lower()

# ──────────────────────────────────────────────
# Adaptive Sampler State Machine
# ──────────────────────────────────────────────
class AdaptiveSampler:
    """
    State machine for adaptive frame sampling.

    States:
    - NORMAL: sample at BASE_FPS (1 fps)
    - HIGH: sample at HIGH_FPS (3 fps) — person recently detected
    - LOW: sample at LOW_FPS (0.5 fps) — long idle period
    """

    def __init__(self, video_fps: float):
        self.video_fps = video_fps
        self.state = "NORMAL"
        self.current_sample_fps = cfg.BASE_FPS
        self.last_person_time = 0.0  # video timestamp of last person detection
        self.no_person_streak = 0.0  # seconds since last person

    def update(self, timestamp: float, person_detected: bool):
        """Update state based on detection result at given video timestamp."""
        if person_detected:
            self.last_person_time = timestamp
            self.no_person_streak = 0.0
            self.state = "HIGH"
            self.current_sample_fps = cfg.HIGH_FPS
        else:
            self.no_person_streak = timestamp - self.last_person_time

            if self.state == "HIGH":
                # Stay high for HIGH_FPS_DURATION after last detection
                if self.no_person_streak > cfg.HIGH_FPS_DURATION:
                    self.state = "NORMAL"
                    self.current_sample_fps = cfg.BASE_FPS

            elif self.state == "NORMAL":
                if self.no_person_streak > cfg.LOW_FPS_THRESHOLD:
                    self.state = "LOW"
                    self.current_sample_fps = cfg.LOW_FPS

            # LOW stays LOW until a person is detected again

    def get_frame_interval(self) -> int:
        """Return the frame interval (number of video frames to skip between samples)."""
        return max(1, int(self.video_fps / self.current_sample_fps))

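# A quick walk-through of the sampler on a hypothetical 30 FPS video, using
# the rates from the docstring (interval = max(1, int(video_fps / sample_fps))):
#
#   sampler = AdaptiveSampler(video_fps=30.0)
#   sampler.get_frame_interval()   # NORMAL @ 1 fps  -> every 30 frames
#   sampler.update(10.0, True)     # person seen     -> HIGH @ 3 fps
#   sampler.get_frame_interval()   #                 -> every 10 frames
#   sampler.update(10.0 + cfg.HIGH_FPS_DURATION + 1, False)
#   sampler.get_frame_interval()   # back to NORMAL  -> every 30 frames
#
# After LOW_FPS_THRESHOLD seconds with no person, NORMAL drops to LOW
# (0.5 fps -> every 60 frames) until the next detection.
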
# ──────────────────────────────────────────────
# Core Extraction
# ──────────────────────────────────────────────
def process_video(
    model: YOLO,
    video_path: str,
    camera_name: str,
    output_images_dir: str,
    output_labels_dir: str,
    global_stats: dict,
) -> dict:
    """
    Process a single video file: extract frames with adaptive sampling,
    detect persons with YOLO, save frames and labels.

    Returns per-video stats dict.
    """
    cam_safe = sanitize_camera_name(camera_name)
    video_name = Path(video_path).stem

    # Check if already done
    checkpoint = load_checkpoint(video_path)
    if checkpoint and checkpoint.get("done"):
        logger.info(f"  ⏭ Skipping (already done): {video_name}")
        return checkpoint

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"  ✗ Cannot open video: {video_path}")
        return {"error": "cannot_open", "video": video_name}

    video_fps = cap.get(cv2.CAP_PROP_FPS)
    if video_fps <= 0:
        # Some containers report 0 FPS; fall back to a sane default so the
        # timestamp math below cannot divide by zero
        logger.warning(f"  ⚠ Invalid FPS metadata ({video_fps}); assuming 25 FPS")
        video_fps = 25.0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration_sec = total_frames / video_fps

    logger.info(f"  ▶ Processing: {video_name}")
    logger.info(f"    FPS: {video_fps:.1f}, Frames: {total_frames}, Duration: {duration_sec:.0f}s")

    # Resume from checkpoint
    start_frame = 0
    frame_counter = 0
    if checkpoint:
        start_frame = checkpoint.get("last_frame", 0)
        frame_counter = checkpoint.get("frame_counter", 0)
        logger.info(f"    Resuming from frame {start_frame}")

    sampler = AdaptiveSampler(video_fps)
    stats = {
        "video": video_name,
        "camera": camera_name,
        "total_frames": total_frames,
        # Carry counts forward from the checkpoint when resuming, so the
        # per-video totals are not silently reset to zero mid-video
        "frames_extracted": checkpoint.get("frames_extracted", 0) if checkpoint else 0,
        "frames_with_person": checkpoint.get("frames_with_person", 0) if checkpoint else 0,
        "frames_without_person": checkpoint.get("frames_without_person", 0) if checkpoint else 0,
        "last_frame": start_frame,
        "frame_counter": frame_counter,
        "done": False,
    }

    # Seek to start position if resuming
    if start_frame > 0:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

    current_frame_idx = start_frame
    batch_frames = []     # [(frame_idx, frame_array, timestamp), ...]
    batch_save_info = []  # [(output_img_path, output_lbl_path), ...]

    last_checkpoint_time = time.time()
    last_size_check = -1    # frames_extracted value at the last size-cap check
    last_progress_log = -1  # frames_extracted value at the last progress log

    while True:
        # Check dataset size cap every 500 extracted frames. frames_extracted
        # only changes on batch flushes, so guard against re-running the
        # expensive directory walk on every loop iteration in between.
        if (
            stats["frames_extracted"] > 0
            and stats["frames_extracted"] % 500 == 0
            and stats["frames_extracted"] != last_size_check
        ):
            last_size_check = stats["frames_extracted"]
            current_size = get_dataset_size_gb(cfg.DATASET_DIR)
            if current_size >= cfg.MAX_DATASET_SIZE_GB:
                logger.warning(f"  ⚠ Dataset size cap reached ({current_size:.1f} GB). Stopping extraction.")
                global_stats["size_cap_reached"] = True
                break

        # Calculate next frame to sample
        interval = sampler.get_frame_interval()
        target_frame = current_frame_idx + interval

        if target_frame >= total_frames:
            break

        # Seek to target frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
        ret, frame = cap.read()
        if not ret:
            break

        current_frame_idx = target_frame
        timestamp = current_frame_idx / video_fps

        # Build unique output filename
        frame_counter += 1
        frame_id = f"{cam_safe}_{frame_counter:06d}"
        img_path = os.path.join(output_images_dir, f"{frame_id}.jpg")
        lbl_path = os.path.join(output_labels_dir, f"{frame_id}.txt")

        batch_frames.append((current_frame_idx, frame, timestamp))
        batch_save_info.append((img_path, lbl_path))

        # Process batch when full
        if len(batch_frames) >= cfg.BATCH_SIZE:
            _process_batch(model, batch_frames, batch_save_info, sampler, stats)
            batch_frames.clear()
            batch_save_info.clear()

        # Checkpoint every 30 seconds
        if time.time() - last_checkpoint_time > 30:
            stats["last_frame"] = current_frame_idx
            stats["frame_counter"] = frame_counter
            save_checkpoint(video_path, stats)
            last_checkpoint_time = time.time()

        # Progress logging every 1000 extractions (same guard as the size
        # check, to avoid duplicate lines between batch flushes)
        if (
            stats["frames_extracted"] > 0
            and stats["frames_extracted"] % 1000 == 0
            and stats["frames_extracted"] != last_progress_log
        ):
            last_progress_log = stats["frames_extracted"]
            pct = (current_frame_idx / total_frames) * 100
            logger.info(
                f"    Progress: {pct:.1f}% | Extracted: {stats['frames_extracted']} | "
                f"Persons: {stats['frames_with_person']} | Mode: {sampler.state} | "
                f"FPS: {sampler.current_sample_fps}"
            )

    # Process remaining batch
    if batch_frames:
        _process_batch(model, batch_frames, batch_save_info, sampler, stats)

    cap.release()

    stats["last_frame"] = current_frame_idx
    stats["frame_counter"] = frame_counter
    mark_video_done(video_path, stats)

    logger.info(
        f"  ✓ Done: {video_name} | Extracted: {stats['frames_extracted']} | "
        f"With person: {stats['frames_with_person']} | "
        f"Without: {stats['frames_without_person']}"
    )

    return stats

def _process_batch(
    model: YOLO,
    batch_frames: list,
    batch_save_info: list,
    sampler: AdaptiveSampler,
    stats: dict,
) -> int:
    """
    Run YOLO inference on a batch of frames, save images + labels.
    Returns number of frames with person detections.
    """
    frames = [f[1] for f in batch_frames]
    timestamps = [f[2] for f in batch_frames]

    # Run YOLO batch inference on GPU
    results = model.predict(
        source=frames,
        conf=cfg.DETECTION_CONF,
        iou=cfg.DETECTION_IOU,
        classes=[cfg.PERSON_CLASS_ID],
        device=0,
        verbose=False,
        half=True,  # FP16 for speed
    )

    persons_count = 0

    for i, result in enumerate(results):
        img_path, lbl_path = batch_save_info[i]
        timestamp = timestamps[i]
        frame = frames[i]

        # Get person detections (predict already filters by class; this is a
        # cheap belt-and-suspenders re-filter)
        boxes = result.boxes
        person_boxes = boxes[boxes.cls == cfg.PERSON_CLASS_ID]
        has_person = len(person_boxes) > 0

        # Update adaptive sampler
        sampler.update(timestamp, has_person)

        # Save frame as JPEG
        cv2.imwrite(img_path, frame, [cv2.IMWRITE_JPEG_QUALITY, cfg.JPEG_QUALITY])

        # Save YOLO-format labels (an empty .txt marks a negative sample)
        h, w = frame.shape[:2]
        with open(lbl_path, "w") as f:
            for box in person_boxes:
                # Convert to YOLO format: class x_center y_center width height (normalized)
                xyxy = box.xyxy[0].cpu().numpy()
                x_center = ((xyxy[0] + xyxy[2]) / 2) / w
                y_center = ((xyxy[1] + xyxy[3]) / 2) / h
                box_w = (xyxy[2] - xyxy[0]) / w
                box_h = (xyxy[3] - xyxy[1]) / h
                # Class is remapped to 0: the dataset is single-class ("person")
                f.write(f"0 {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}\n")

        stats["frames_extracted"] += 1
        if has_person:
            stats["frames_with_person"] += 1
            persons_count += 1
        else:
            stats["frames_without_person"] += 1

    return persons_count

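# Each label file holds one line per detected person in normalized YOLO
# format, "class x_center y_center width height" with coordinates in [0, 1]
# and class fixed at 0. Illustrative contents:
#
#   0 0.512340 0.433210 0.120000 0.250000
#   0 0.803200 0.610000 0.095000 0.310000
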
# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────
def extract_all() -> dict:
    """
    Main entry point: discover videos, extract dataset.
    Returns {camera_name: [video_stats]}.
    """
    logger.info("=" * 60)
    logger.info("DATASET EXTRACTION STARTED")
    logger.info("=" * 60)

    # Load YOLO model on GPU
    logger.info(f"Loading detector model: {cfg.DETECTOR_MODEL}")
    model = YOLO(cfg.DETECTOR_MODEL)
    model.to("cuda")
    logger.info("Model loaded on GPU ✓")

    # Discover cameras and videos
    cameras = discover_videos()
    total_videos = sum(len(v) for v in cameras.values())
    logger.info(f"Found {len(cameras)} cameras, {total_videos} videos total")

    # Create output directories (flat — split happens later)
    images_dir = os.path.join(cfg.DATASET_DIR, "images", "all")
    labels_dir = os.path.join(cfg.DATASET_DIR, "labels", "all")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    # Save camera mapping for split script
    camera_mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")

    # Load existing mapping if resuming
    if os.path.exists(camera_mapping_path):
        with open(camera_mapping_path, "r") as f:
            camera_mapping = json.load(f)
    else:
        camera_mapping = {}

    global_stats = {"size_cap_reached": False}
    all_stats = {}
    video_num = 0

    for cam_name, video_list in cameras.items():
        cam_safe = sanitize_camera_name(cam_name)
        logger.info(f"\n{'─' * 40}")
        logger.info(f"Camera: {cam_name} ({len(video_list)} videos)")
        logger.info(f"{'─' * 40}")

        cam_stats = []
        for video_path in video_list:
            video_num += 1
            logger.info(f"\n[{video_num}/{total_videos}]")

            if global_stats["size_cap_reached"]:
                logger.warning("Size cap reached — skipping remaining videos.")
                break

            vstats = process_video(
                model, video_path, cam_name,
                images_dir, labels_dir, global_stats
            )
            cam_stats.append(vstats)

            # Track which frames belong to which camera
            if cam_name not in camera_mapping:
                camera_mapping[cam_name] = {"safe_name": cam_safe, "frames": []}

            # Collect frame IDs for this camera
            frame_prefix = cam_safe + "_"
            existing_frames = camera_mapping[cam_name].get("frames", [])
            new_frames = [
                f for f in os.listdir(images_dir)
                if f.startswith(frame_prefix) and f.endswith(".jpg")
            ]
            camera_mapping[cam_name]["frames"] = sorted(set(existing_frames + new_frames))

            # Save mapping after each video
            with open(camera_mapping_path, "w") as f:
                json.dump(camera_mapping, f, indent=2)

        all_stats[cam_name] = cam_stats

        if global_stats["size_cap_reached"]:
            break

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("EXTRACTION COMPLETE")
    logger.info("=" * 60)

    total_extracted = 0
    total_persons = 0
    for cam, stats_list in all_stats.items():
        for s in stats_list:
            total_extracted += s.get("frames_extracted", 0)
            total_persons += s.get("frames_with_person", 0)

    dataset_size = get_dataset_size_gb(cfg.DATASET_DIR)
    logger.info(f"Total frames extracted: {total_extracted}")
    logger.info(f"Frames with persons: {total_persons}")
    logger.info(f"Dataset size: {dataset_size:.2f} GB")
    logger.info(f"Cameras processed: {len(cameras)}")

    # Save final stats
    stats_path = os.path.join(cfg.DATASET_DIR, "extraction_stats.json")
    with open(stats_path, "w") as f:
        json.dump({
            "total_extracted": total_extracted,
            "total_with_persons": total_persons,
            "dataset_size_gb": round(dataset_size, 2),
            "cameras": {k: len(v) for k, v in cameras.items()},
            "all_stats": all_stats,
        }, f, indent=2, default=str)

    return all_stats

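# Resulting dataset layout (train/val split happens in a later pipeline step):
#
#   cfg.DATASET_DIR/
#     images/all/<camera>_<NNNNNN>.jpg
#     labels/all/<camera>_<NNNNNN>.txt
#     camera_mapping.json
#     extraction_stats.json
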
if __name__ == "__main__":
    extract_all()