"""
|
|
split_dataset.py
|
|
────────────────
|
|
Camera-level train/test split for the extracted dataset.
|
|
|
|
Splits ENTIRE cameras into train or test sets to prevent data leakage
|
|
from similar/consecutive surveillance frames.
|
|
|
|
- If >= 5 cameras: hold out 4 cameras for test
|
|
- If < 5 cameras: hold out 1 camera for test
|
|
"""
import os
import sys
import json
import shutil
import random
import logging
from pathlib import Path  # NOTE(review): currently unused in this module — confirm before removing
from datetime import datetime

import pipeline_config as cfg

# ──────────────────────────────────────────────
# Logging
# ──────────────────────────────────────────────
# One timestamped log file per run, mirrored to stdout.
os.makedirs(cfg.LOG_DIR, exist_ok=True)
log_file = os.path.join(cfg.LOG_DIR, f"split_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
def load_camera_mapping() -> dict:
    """Return the camera→frames mapping written by extract_dataset.py.

    Logs an error and exits the process when the mapping file is absent.
    """
    mapping_path = os.path.join(cfg.DATASET_DIR, "camera_mapping.json")
    try:
        with open(mapping_path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        logger.error(f"Camera mapping not found: {mapping_path}")
        logger.error("Run extract_dataset.py first!")
        sys.exit(1)
|
def split_cameras(camera_names: list[str]) -> tuple[list[str], list[str]]:
    """
    Split camera names into train and test sets.

    Entire cameras are held out so near-duplicate frames from one camera
    can never appear on both sides of the split.

    - >= 5 cameras: hold out up to cfg.TEST_CAMERAS cameras for test
    - 2-4 cameras:  hold out exactly 1 camera for test
    Exits the process when fewer than 2 cameras are available.

    Returns (train_cameras, test_cameras).
    """
    # Use a dedicated RNG instance so seeding for reproducibility does not
    # clobber the global `random` state for other code in this process.
    # random.Random(seed).shuffle yields the exact same sequence as
    # random.seed(seed) followed by random.shuffle.
    rng = random.Random(cfg.RANDOM_SEED)

    n_cameras = len(camera_names)
    if n_cameras < 2:
        logger.error(f"Need at least 2 cameras for splitting, found {n_cameras}")
        sys.exit(1)

    # Determine how many cameras to hold out
    if n_cameras >= 5:
        n_test = min(cfg.TEST_CAMERAS, n_cameras - 1)  # At least 1 for train
    else:
        n_test = 1

    # Shuffle a copy so the caller's list order is left untouched.
    shuffled = camera_names.copy()
    rng.shuffle(shuffled)

    test_cameras = shuffled[:n_test]
    train_cameras = shuffled[n_test:]

    return train_cameras, test_cameras
|
def move_frames(
    camera_mapping: dict,
    camera_names: list[str],
    split_name: str,
    src_images_dir: str,
    src_labels_dir: str,
) -> int:
    """Move frames belonging to the given cameras into the split directory.

    Args:
        camera_mapping: camera name -> per-camera dict with a "frames" list
            of image filenames (as written by extract_dataset.py).
        camera_names: cameras whose frames belong to this split.
        split_name: destination subdirectory name, e.g. "train" or "test".
        src_images_dir: directory currently holding all images.
        src_labels_dir: directory currently holding all YOLO label files.

    Returns:
        Number of frames processed. Missing source files are skipped
        silently but still counted.
    """
    dst_images = os.path.join(cfg.DATASET_DIR, "images", split_name)
    dst_labels = os.path.join(cfg.DATASET_DIR, "labels", split_name)
    os.makedirs(dst_images, exist_ok=True)
    os.makedirs(dst_labels, exist_ok=True)

    total_moved = 0

    for cam_name in camera_names:
        cam_data = camera_mapping[cam_name]
        frames = cam_data.get("frames", [])

        for img_filename in frames:
            # Image
            src_img = os.path.join(src_images_dir, img_filename)
            dst_img = os.path.join(dst_images, img_filename)
            if os.path.exists(src_img):
                shutil.move(src_img, dst_img)

            # Label: swap only the final extension. The previous
            # replace(".jpg", ".txt") corrupted names containing ".jpg"
            # mid-string and silently failed for non-.jpg extensions.
            lbl_filename = os.path.splitext(img_filename)[0] + ".txt"
            src_lbl = os.path.join(src_labels_dir, lbl_filename)
            dst_lbl = os.path.join(dst_labels, lbl_filename)
            if os.path.exists(src_lbl):
                shutil.move(src_lbl, dst_lbl)

            total_moved += 1

    logger.info(f" {split_name}: {total_moved} frames from {len(camera_names)} cameras")
    return total_moved
|
def create_dataset_yaml(train_cameras: list[str], test_cameras: list[str]):
    """Write the YOLO dataset.yaml describing this split and return its path."""
    yaml_path = os.path.join(cfg.DATASET_DIR, "dataset.yaml")

    # YOLO expects POSIX-style separators, even on Windows.
    dataset_root = cfg.DATASET_DIR.replace("\\", "/")

    # Assemble the file line by line; a comment header records provenance.
    lines = [
        "# Auto-generated dataset config for person detection",
        f"# Generated: {datetime.now().isoformat()}",
        "#",
        f"# Train cameras: {', '.join(train_cameras)}",
        f"# Test cameras: {', '.join(test_cameras)}",
        "",
        f"path: {dataset_root}",
        "train: images/train",
        "val: images/test",
        "",
        "nc: 1",
        "names: ['person']",
    ]

    with open(yaml_path, "w") as f:
        f.write("\n".join(lines) + "\n")

    logger.info(f"Created dataset.yaml at {yaml_path}")
    return yaml_path
|
def split_dataset() -> dict:
    """
    Main entry point: split extracted dataset by camera.
    Returns split info dict.

    Side effects under cfg.DATASET_DIR: moves files out of images/all and
    labels/all into images/{train,test} and labels/{train,test}; writes
    dataset.yaml, classes.txt and split_info.json. Exits the process if
    the camera mapping or the images/all directory is missing.
    """
    logger.info("=" * 60)
    logger.info("DATASET SPLIT (CAMERA-LEVEL)")
    logger.info("=" * 60)

    # Load camera mapping (load_camera_mapping exits if it is missing)
    camera_mapping = load_camera_mapping()
    # Sorted for a deterministic camera order before the seeded shuffle.
    camera_names = sorted(camera_mapping.keys())
    logger.info(f"Found {len(camera_names)} cameras: {camera_names}")

    # Count frames per camera
    for cam in camera_names:
        n_frames = len(camera_mapping[cam].get("frames", []))
        logger.info(f" {cam}: {n_frames} frames")

    # Split cameras
    train_cameras, test_cameras = split_cameras(camera_names)
    logger.info(f"\nTrain cameras ({len(train_cameras)}): {train_cameras}")
    logger.info(f"Test cameras ({len(test_cameras)}): {test_cameras}")

    # Source directories populated by extract_dataset.py
    src_images = os.path.join(cfg.DATASET_DIR, "images", "all")
    src_labels = os.path.join(cfg.DATASET_DIR, "labels", "all")

    if not os.path.exists(src_images):
        logger.error(f"Source images directory not found: {src_images}")
        sys.exit(1)

    # Move frames to train/test directories
    logger.info("\nMoving frames to split directories...")
    n_train = move_frames(camera_mapping, train_cameras, "train", src_images, src_labels)
    n_test = move_frames(camera_mapping, test_cameras, "test", src_images, src_labels)

    # Clean up empty 'all' directory.
    # Best-effort: any OS error here is logged and swallowed so a cleanup
    # failure never aborts an otherwise successful split.
    try:
        remaining_imgs = os.listdir(src_images)
        if not remaining_imgs:
            os.rmdir(src_images)
            src_labels_check = os.path.join(cfg.DATASET_DIR, "labels", "all")
            remaining_lbls = os.listdir(src_labels_check)
            if not remaining_lbls:
                os.rmdir(src_labels_check)
            logger.info("Cleaned up empty 'all' directories")
        else:
            # Frames present on disk but absent from the mapping stay behind.
            logger.warning(f"{len(remaining_imgs)} orphan images left in 'all' directory")
    except Exception as e:
        logger.warning(f"Cleanup note: {e}")

    # Create dataset.yaml
    yaml_path = create_dataset_yaml(train_cameras, test_cameras)

    # Create classes.txt (single class: person)
    classes_path = os.path.join(cfg.DATASET_DIR, "classes.txt")
    with open(classes_path, "w") as f:
        f.write("person\n")

    # Save split info so downstream steps can see exactly how we split
    split_info = {
        "train_cameras": train_cameras,
        "test_cameras": test_cameras,
        "train_frames": n_train,
        "test_frames": n_test,
        "random_seed": cfg.RANDOM_SEED,
        "yaml_path": yaml_path,
        "timestamp": datetime.now().isoformat(),
    }

    split_info_path = os.path.join(cfg.DATASET_DIR, "split_info.json")
    with open(split_info_path, "w") as f:
        json.dump(split_info, f, indent=2)

    logger.info("\n" + "=" * 60)
    logger.info("SPLIT COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Train: {n_train} frames from {len(train_cameras)} cameras")
    logger.info(f"Test: {n_test} frames from {len(test_cameras)} cameras")
    # NOTE(review): assumes n_train + n_test > 0; split_cameras guarantees
    # >= 2 cameras, but an all-empty mapping would raise ZeroDivisionError here.
    logger.info(f"Ratio: {n_train/(n_train+n_test)*100:.1f}% / {n_test/(n_train+n_test)*100:.1f}%")

    return split_info
# Allow running this module directly as a script.
if __name__ == "__main__":
    split_dataset()