"""
split_dataset.py
────────────────
Camera-level train/test split for the extracted dataset.

Splits ENTIRE cameras into train or test sets to prevent data leakage
from similar/consecutive surveillance frames.

- If >= 5 cameras: hold out cfg.TEST_CAMERAS cameras for test
  (capped so that at least one camera remains for training)
- If < 5 cameras: hold out 1 camera for test
"""

import os
import sys
import json
import shutil
import random
import logging
from pathlib import Path
from datetime import datetime

import pipeline_config as cfg

# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
os.makedirs(cfg.LOG_DIR, exist_ok=True)
log_file = os.path.join(cfg.LOG_DIR, f"split_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)


def load_camera_mapping() -> dict:
    """Load the camera→frames mapping from ``camera_mapping.json``.

    The file is read from ``cfg.DATASET_DIR``. If it does not exist the
    error is logged and the process exits with status 1.

    Returns:
        The parsed JSON document. The rest of this module indexes it by
        camera name and reads an optional ``"frames"`` list from each entry.
    """
    mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")
    if not os.path.exists(mapping_path):
        logger.error(f"Camera mapping not found: {mapping_path}")
        logger.error("Run extract_dataset.py first!")
        sys.exit(1)

    with open(mapping_path, "r", encoding="utf-8") as f:
        return json.load(f)


def split_cameras(camera_names: list[str]) -> tuple[list[str], list[str]]:
    """Split camera names into train and test sets.

    The global ``random`` generator is seeded with ``cfg.RANDOM_SEED`` so the
    same input list always produces the same split.

    Args:
        camera_names: All camera names; the list itself is not modified.

    Returns:
        ``(train_cameras, test_cameras)``.

    Exits the process with status 1 when fewer than 2 cameras are given.
    """
    random.seed(cfg.RANDOM_SEED)

    n_cameras = len(camera_names)
    if n_cameras < 2:
        logger.error(f"Need at least 2 cameras for splitting, found {n_cameras}")
        sys.exit(1)

    # Determine how many cameras to hold out
    if n_cameras >= 5:
        # Keep at least 1 camera for train, and never hold out fewer than 1:
        # a zero/negative cfg.TEST_CAMERAS would otherwise yield an empty or
        # nonsensical (negative-slice) test set.
        n_test = max(1, min(cfg.TEST_CAMERAS, n_cameras - 1))
    else:
        n_test = 1

    shuffled = camera_names.copy()
    random.shuffle(shuffled)

    test_cameras = shuffled[:n_test]
    train_cameras = shuffled[n_test:]
    return train_cameras, test_cameras


def move_frames(
    camera_mapping: dict,
    camera_names: list[str],
    split_name: str,
    src_images_dir: str,
    src_labels_dir: str,
) -> int:
    """Move frames belonging to the given cameras into the split directory.

    Images go to ``<DATASET_DIR>/images/<split_name>`` and their label files
    to ``<DATASET_DIR>/labels/<split_name>``; both directories are created if
    needed.

    Args:
        camera_mapping: Mapping of camera name to an entry holding a
            ``"frames"`` list of image filenames.
        camera_names: Cameras whose frames should be moved.
        split_name: Name of the destination sub-directory (e.g. ``"train"``).
        src_images_dir: Directory currently holding the images.
        src_labels_dir: Directory currently holding the label ``.txt`` files.

    Returns:
        Number of frames whose image was actually moved.
    """
    dst_images = os.path.join(cfg.DATASET_DIR, "images", split_name)
    dst_labels = os.path.join(cfg.DATASET_DIR, "labels", split_name)
    os.makedirs(dst_images, exist_ok=True)
    os.makedirs(dst_labels, exist_ok=True)

    total_moved = 0
    missing_images = 0
    for cam_name in camera_names:
        frames = camera_mapping[cam_name].get("frames", [])
        for img_filename in frames:
            # Image
            src_img = os.path.join(src_images_dir, img_filename)
            if not os.path.exists(src_img):
                missing_images += 1
                continue
            shutil.move(src_img, os.path.join(dst_images, img_filename))

            # Label: swap only the final extension so that .png/.jpeg/.JPG
            # frames and names containing ".jpg" elsewhere map correctly.
            lbl_filename = os.path.splitext(img_filename)[0] + ".txt"
            src_lbl = os.path.join(src_labels_dir, lbl_filename)
            if os.path.exists(src_lbl):
                shutil.move(src_lbl, os.path.join(dst_labels, lbl_filename))

            total_moved += 1

    if missing_images:
        logger.warning(
            f"  {split_name}: {missing_images} frames listed in the camera mapping "
            f"were not found in {src_images_dir}"
        )
    logger.info(f"  {split_name}: {total_moved} frames from {len(camera_names)} cameras")
    return total_moved


def create_dataset_yaml(train_cameras: list[str], test_cameras: list[str]):
    """Create ``dataset.yaml`` for YOLO training inside ``cfg.DATASET_DIR``.

    The camera names of each split are recorded as comments in the file
    header; the test split is registered as the ``val`` set.

    Args:
        train_cameras: Cameras assigned to the train split.
        test_cameras: Cameras assigned to the test split.

    Returns:
        Path of the written YAML file.
    """
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")

    # Use forward slashes for YOLO compatibility. str() keeps this working
    # if the config value is a pathlib.Path rather than a plain string.
    dataset_path = str(cfg.DATASET_DIR).replace("\\", "/")
    train_list = ", ".join(train_cameras)
    test_list = ", ".join(test_cameras)

    content = f"""# Auto-generated dataset config for person detection
# Generated: {datetime.now().isoformat()}
#
# Train cameras: {train_list}
# Test cameras:  {test_list}

path: {dataset_path}
train: images/train
val: images/test

nc: 1
names: ['person']
"""

    # Explicit UTF-8 so non-ASCII camera names in the header do not depend on
    # the platform's default encoding.
    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write(content)

    logger.info(f"Created dataset.yaml at {yaml_path}")
    return yaml_path


def _cleanup_staging_dirs(src_images: str, src_labels: str) -> None:
    """Remove the 'all' staging directories if the split left them empty."""
    try:
        remaining_imgs = os.listdir(src_images)
        if remaining_imgs:
            logger.warning(f"{len(remaining_imgs)} orphan images left in 'all' directory")
            return
        os.rmdir(src_images)

        if os.path.isdir(src_labels):
            remaining_lbls = os.listdir(src_labels)
            if remaining_lbls:
                logger.warning(f"{len(remaining_lbls)} orphan labels left in 'all' directory")
                return
            os.rmdir(src_labels)

        logger.info("Cleaned up empty 'all' directories")
    except OSError as e:
        # Cleanup is best-effort; a failure here must not abort the split.
        logger.warning(f"Cleanup note: {e}")


def split_dataset() -> dict:
    """Main entry point: split the extracted dataset by camera.

    Loads the camera mapping, assigns whole cameras to train/test, moves the
    frames from ``images/all`` and ``labels/all`` into per-split directories,
    removes the staging directories when empty, and writes ``dataset.yaml``,
    ``classes.txt`` and ``split_info.json`` into ``cfg.DATASET_DIR``.

    Returns:
        The split info dict that is also saved to ``split_info.json``.

    Exits the process with status 1 when ``images/all`` does not exist.
    """
    logger.info("=" * 60)
    logger.info("DATASET SPLIT (CAMERA-LEVEL)")
    logger.info("=" * 60)

    # Load camera mapping
    camera_mapping = load_camera_mapping()
    camera_names = sorted(camera_mapping.keys())
    logger.info(f"Found {len(camera_names)} cameras: {camera_names}")

    # Count frames per camera
    for cam in camera_names:
        n_frames = len(camera_mapping[cam].get("frames", []))
        logger.info(f"  {cam}: {n_frames} frames")

    # Split cameras
    train_cameras, test_cameras = split_cameras(camera_names)
    logger.info(f"\nTrain cameras ({len(train_cameras)}): {train_cameras}")
    logger.info(f"Test cameras ({len(test_cameras)}): {test_cameras}")

    # Source directories
    src_images = os.path.join(cfg.DATASET_DIR, "images", "all")
    src_labels = os.path.join(cfg.DATASET_DIR, "labels", "all")

    if not os.path.exists(src_images):
        logger.error(f"Source images directory not found: {src_images}")
        sys.exit(1)

    # Move frames to train/test directories
    logger.info("\nMoving frames to split directories...")
    n_train = move_frames(camera_mapping, train_cameras, "train", src_images, src_labels)
    n_test = move_frames(camera_mapping, test_cameras, "test", src_images, src_labels)

    # Clean up empty 'all' directory
    _cleanup_staging_dirs(src_images, src_labels)

    # Create dataset.yaml
    yaml_path = create_dataset_yaml(train_cameras, test_cameras)

    # Create classes.txt
    classes_path = os.path.join(cfg.DATASET_DIR, "classes.txt")
    with open(classes_path, "w", encoding="utf-8") as f:
        f.write("person\n")

    # Save split info
    split_info = {
        "train_cameras": train_cameras,
        "test_cameras": test_cameras,
        "train_frames": n_train,
        "test_frames": n_test,
        "random_seed": cfg.RANDOM_SEED,
        "yaml_path": yaml_path,
        "timestamp": datetime.now().isoformat(),
    }
    split_info_path = os.path.join(cfg.DATASET_DIR, "split_info.json")
    with open(split_info_path, "w", encoding="utf-8") as f:
        json.dump(split_info, f, indent=2)

    logger.info("\n" + "=" * 60)
    logger.info("SPLIT COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Train: {n_train} frames from {len(train_cameras)} cameras")
    logger.info(f"Test:  {n_test} frames from {len(test_cameras)} cameras")

    total_frames = n_train + n_test
    if total_frames:
        logger.info(
            f"Ratio: {n_train/total_frames*100:.1f}% / {n_test/total_frames*100:.1f}%"
        )
    else:
        # Avoid ZeroDivisionError when nothing was moved (e.g. the mapping
        # lists no frames, or none of the listed files exist).
        logger.warning("Ratio: n/a (no frames were moved)")

    return split_info


if __name__ == "__main__":
    split_dataset()