# utopia-surveillance-tool/split_dataset.py
"""
split_dataset.py
────────────────
Camera-level train/test split for the extracted dataset.
Splits ENTIRE cameras into train or test sets to prevent data leakage
from similar/consecutive surveillance frames.
- If >= 5 cameras: hold out 4 cameras for test
- If < 5 cameras: hold out 1 camera for test
"""
import os
import sys
import json
import shutil
import random
import logging
from pathlib import Path
from datetime import datetime
import pipeline_config as cfg
# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# Ensure the log directory exists before attaching a FileHandler to it.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
# One timestamped log file per run, e.g. split_20240101_120000.log
log_file = os.path.join(cfg.LOG_DIR, f"split_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),      # persistent copy on disk
        logging.StreamHandler(sys.stdout),  # mirrored to the console
    ],
)
logger = logging.getLogger(__name__)
def load_camera_mapping() -> dict:
    """Load the camera→frames mapping created during extraction.

    Exits the process (status 1) if extract_dataset.py has not produced
    the mapping file yet.
    """
    mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")
    # EAFP: attempt the read and treat a missing file as the fatal case.
    try:
        with open(mapping_path, "r") as fh:
            return json.load(fh)
    except FileNotFoundError:
        logger.error(f"Camera mapping not found: {mapping_path}")
        logger.error("Run extract_dataset.py first!")
        sys.exit(1)
def split_cameras(camera_names: list[str]) -> tuple[list[str], list[str]]:
    """
    Split camera names into train and test sets.
    Returns (train_cameras, test_cameras).
    """
    # Deterministic shuffle so repeated runs produce the same split.
    random.seed(cfg.RANDOM_SEED)
    total = len(camera_names)
    if total < 2:
        logger.error(f"Need at least 2 cameras for splitting, found {total}")
        sys.exit(1)
    # With 5+ cameras hold out the configured number (capping so at least
    # one camera remains for training); with fewer, hold out exactly one.
    held_out = min(cfg.TEST_CAMERAS, total - 1) if total >= 5 else 1
    pool = camera_names.copy()
    random.shuffle(pool)
    # First `held_out` shuffled cameras become the test set.
    return pool[held_out:], pool[:held_out]
def move_frames(
    camera_mapping: dict,
    camera_names: list[str],
    split_name: str,
    src_images_dir: str,
    src_labels_dir: str,
):
    """Move frames belonging to given cameras into the split directory.

    Args:
        camera_mapping: camera name -> {"frames": [image filenames], ...}.
        camera_names: cameras whose frames belong to this split.
        split_name: split subdirectory name ("train" or "test").
        src_images_dir: directory currently holding the extracted images.
        src_labels_dir: directory currently holding the YOLO label files.

    Returns:
        Number of frames actually moved (frames whose image file existed).
    """
    dst_images = os.path.join(cfg.DATASET_DIR, "images", split_name)
    dst_labels = os.path.join(cfg.DATASET_DIR, "labels", split_name)
    os.makedirs(dst_images, exist_ok=True)
    os.makedirs(dst_labels, exist_ok=True)
    total_moved = 0
    for cam_name in camera_names:
        cam_data = camera_mapping[cam_name]
        frames = cam_data.get("frames", [])
        for img_filename in frames:
            # Image: only count frames that were really moved (the old code
            # counted every listed frame, even when the file was missing).
            src_img = os.path.join(src_images_dir, img_filename)
            dst_img = os.path.join(dst_images, img_filename)
            if os.path.exists(src_img):
                shutil.move(src_img, dst_img)
                total_moved += 1
            # Label: swap the extension via splitext. The previous
            # str.replace(".jpg", ".txt") corrupted names containing ".jpg"
            # elsewhere and silently missed non-.jpg image extensions.
            lbl_filename = os.path.splitext(img_filename)[0] + ".txt"
            src_lbl = os.path.join(src_labels_dir, lbl_filename)
            dst_lbl = os.path.join(dst_labels, lbl_filename)
            if os.path.exists(src_lbl):
                shutil.move(src_lbl, dst_lbl)
    logger.info(f" {split_name}: {total_moved} frames from {len(camera_names)} cameras")
    return total_moved
def create_dataset_yaml(train_cameras: list[str], test_cameras: list[str]):
    """Write dataset.yaml describing the split for YOLO training.

    Returns the path of the written config file.
    """
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")
    # YOLO configs want forward slashes, even on Windows.
    dataset_path = cfg.DATASET_DIR.replace("\\", "/")
    content = f"""# Auto-generated dataset config for person detection
# Generated: {datetime.now().isoformat()}
#
# Train cameras: {', '.join(train_cameras)}
# Test cameras: {', '.join(test_cameras)}
path: {dataset_path}
train: images/train
val: images/test
nc: 1
names: ['person']
"""
    Path(yaml_path).write_text(content)
    logger.info(f"Created dataset.yaml at {yaml_path}")
    return yaml_path
def split_dataset() -> dict:
    """
    Main entry point: split extracted dataset by camera.

    Loads the camera mapping, assigns whole cameras to train/test, moves
    frames into the split directories, writes dataset.yaml, classes.txt and
    split_info.json, then returns the split info dict.
    """
    logger.info("=" * 60)
    logger.info("DATASET SPLIT (CAMERA-LEVEL)")
    logger.info("=" * 60)
    # Load camera mapping (exits if extract_dataset.py has not run yet)
    camera_mapping = load_camera_mapping()
    camera_names = sorted(camera_mapping.keys())
    logger.info(f"Found {len(camera_names)} cameras: {camera_names}")
    # Count frames per camera
    for cam in camera_names:
        n_frames = len(camera_mapping[cam].get("frames", []))
        logger.info(f" {cam}: {n_frames} frames")
    # Split cameras
    train_cameras, test_cameras = split_cameras(camera_names)
    logger.info(f"\nTrain cameras ({len(train_cameras)}): {train_cameras}")
    logger.info(f"Test cameras ({len(test_cameras)}): {test_cameras}")
    # Source directories populated during extraction
    src_images = os.path.join(cfg.DATASET_DIR, "images", "all")
    src_labels = os.path.join(cfg.DATASET_DIR, "labels", "all")
    if not os.path.exists(src_images):
        logger.error(f"Source images directory not found: {src_images}")
        sys.exit(1)
    # Move frames to train/test directories
    logger.info("\nMoving frames to split directories...")
    n_train = move_frames(camera_mapping, train_cameras, "train", src_images, src_labels)
    n_test = move_frames(camera_mapping, test_cameras, "test", src_images, src_labels)
    # Clean up empty 'all' directories (best-effort; never fatal).
    # Note: the labels path is src_labels — the old code rebuilt it redundantly.
    try:
        remaining_imgs = os.listdir(src_images)
        if not remaining_imgs:
            os.rmdir(src_images)
            remaining_lbls = os.listdir(src_labels)
            if not remaining_lbls:
                os.rmdir(src_labels)
            logger.info("Cleaned up empty 'all' directories")
        else:
            logger.warning(f"{len(remaining_imgs)} orphan images left in 'all' directory")
    except Exception as e:
        logger.warning(f"Cleanup note: {e}")
    # Create dataset.yaml
    yaml_path = create_dataset_yaml(train_cameras, test_cameras)
    # Create classes.txt (single 'person' class)
    classes_path = os.path.join(cfg.DATASET_DIR, "classes.txt")
    with open(classes_path, "w") as f:
        f.write("person\n")
    # Save split info
    split_info = {
        "train_cameras": train_cameras,
        "test_cameras": test_cameras,
        "train_frames": n_train,
        "test_frames": n_test,
        "random_seed": cfg.RANDOM_SEED,
        "yaml_path": yaml_path,
        "timestamp": datetime.now().isoformat(),
    }
    split_info_path = os.path.join(cfg.DATASET_DIR, "split_info.json")
    with open(split_info_path, "w") as f:
        json.dump(split_info, f, indent=2)
    logger.info("\n" + "=" * 60)
    logger.info("SPLIT COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Train: {n_train} frames from {len(train_cameras)} cameras")
    logger.info(f"Test: {n_test} frames from {len(test_cameras)} cameras")
    # Guard against ZeroDivisionError when no frames were moved at all
    # (the old code crashed here on an empty dataset).
    total = n_train + n_test
    if total > 0:
        logger.info(f"Ratio: {n_train/total*100:.1f}% / {n_test/total*100:.1f}%")
    else:
        logger.warning("No frames were moved; check the camera mapping and source directories")
    return split_info
if __name__ == "__main__":
    # Run the camera-level split when executed as a standalone script.
    split_dataset()