"""
Dataset & Augmentation Pipeline
================================

Handles:
  • Random pairing (same-class → label=1, diff-class → label=0)
  • Heavy augmentation for noise robustness:
      - Random crops + resize      → simulates cropping
      - Random down/up-scaling     → simulates resolution differences
      - Gaussian blur + noise      → simulates sensor noise
      - Color jitter               → lighting variation
      - Horizontal flip, rotation
"""

import os
import random
from pathlib import Path
from typing import Optional, Tuple, Callable

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFilter, ImageEnhance
import numpy as np


# ─────────────────────────────────────────────────────────────
#  Custom Transforms
# ─────────────────────────────────────────────────────────────

class AddGaussianNoise:
    """Add zero-mean Gaussian pixel noise to a tensor image and clamp to [0, 1].

    Intended to run after ``ToTensor`` (values in [0, 1]) and before
    ``Normalize``, so that the clamp is applied in pixel space.

    Args:
        std: Standard deviation of the noise, in [0, 1] pixel units.

    Raises:
        ValueError: If ``std`` is negative.
    """

    def __init__(self, std: float = 0.05):
        if std < 0:
            raise ValueError(f"std must be non-negative, got {std}")
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        noise = torch.randn_like(tensor) * self.std
        return torch.clamp(tensor + noise, 0.0, 1.0)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(std={self.std})"


class RandomResolution:
    """Simulate a lower-resolution capture by downscaling, then upscaling back.

    With probability ``p`` the PIL image is resized to a random fraction of its
    size (factor drawn uniformly from ``scale``) and then resized back to its
    original size. Fine detail is lost while the output size stays unchanged.
    The factor is sampled on every call, i.e. per image.

    Args:
        scale: ``(min, max)`` fraction of the original side lengths,
            with ``0 < min <= max <= 1``.
        p: Probability of applying the transform.

    Raises:
        ValueError: If ``scale`` or ``p`` is out of range.
    """

    def __init__(self, scale: Tuple[float, float] = (0.5, 1.0), p: float = 0.5):
        lo, hi = scale
        if not 0.0 < lo <= hi <= 1.0:
            raise ValueError(f"scale must satisfy 0 < min <= max <= 1, got {scale}")
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")
        self.scale = (lo, hi)
        self.p = p

    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() >= self.p:
            return img
        factor = random.uniform(*self.scale)
        width, height = img.size
        small = (max(1, round(width * factor)), max(1, round(height * factor)))
        return img.resize(small, Image.BILINEAR).resize((width, height), Image.BILINEAR)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(scale={self.scale}, p={self.p})"


# ─────────────────────────────────────────────────────────────
#  Augmentation Pipelines
# ─────────────────────────────────────────────────────────────

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def get_train_transform(img_size: int = 224) -> transforms.Compose:
    """Heavy augmentation for training — robust to crops, resolution, noise.

    Takes a PIL image and returns a normalized float tensor that is
    ``img_size × img_size`` spatially (channel count follows the input; the
    datasets in this module feed RGB images).
    """
    return transforms.Compose([
        # Bring the shorter side to img_size so the crop scale below is
        # relative to a consistent working resolution.
        transforms.Resize(img_size),
        # Simulate resolution differences. This must be a transform sampled
        # per call: a random size baked into Resize() at construction time
        # would be identical for every image.
        RandomResolution(scale=(0.5, 1.0), p=0.5),
        transforms.RandomResizedCrop(
            img_size,
            scale=(0.5, 1.0),      # aggressive cropping (50 – 100 % of image area)
            ratio=(0.75, 1.33),
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1,
        ),
        transforms.RandomGrayscale(p=0.1),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3),
        transforms.ToTensor(),
        # Pixel noise is added in [0, 1] space, before normalization.
        AddGaussianNoise(std=0.03),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


def get_val_transform(img_size: int = 224) -> transforms.Compose:
    """Deterministic transform for validation/inference (resize → tensor → normalize)."""
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# ─────────────────────────────────────────────────────────────
#  Siamese Pair Dataset
# ─────────────────────────────────────────────────────────────

# File suffixes (compared lower-cased) treated as images.
IMG_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp"})


def _load_rgb(path) -> Image.Image:
    """Open an image file and return an RGB copy, closing the file handle."""
    with Image.open(path) as img:
        # convert() loads the pixel data, so the result outlives the file.
        return img.convert("RGB")


class SiamesePairDataset(Dataset):
    """
    Randomly paired images from a class-per-folder directory tree.

    Folder layout expected:
        root/
          class_A/  img1.jpg  img2.jpg ...
          class_B/  img1.jpg ...

    Each __getitem__ returns (img1_tensor, img2_tensor, label), where label is a
    float32 scalar tensor:
        label = 1  → same class
        label = 0  → different class

    Args:
        root: Directory whose sub-directories are the classes.
        transform: Applied independently to each PIL image of a pair.
        pairs_per_epoch: Number of pairs generated per (re)generation.
        same_ratio: Fraction of pairs that are same-class, in [0, 1].
        seed: If given, pair generation uses a private RNG seeded with it, so
            the pair list is reproducible. If None, the global ``random``
            module is used.

    Raises:
        ValueError: If ``same_ratio`` is outside [0, 1] or fewer than two
            class folders contain images.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pairs_per_epoch: int = 5000,
        same_ratio: float = 0.5,
        seed: Optional[int] = None,
    ):
        if not 0.0 <= same_ratio <= 1.0:
            raise ValueError(f"same_ratio must be in [0, 1], got {same_ratio}")
        self.root = Path(root)
        self.transform = transform
        self.pairs_per_epoch = pairs_per_epoch
        self.same_ratio = same_ratio
        # Store a Random instance (picklable for DataLoader workers), never
        # the `random` module itself.
        self._rng = random.Random(seed) if seed is not None else None

        # Build class → [image paths] mapping. Both levels are sorted so the
        # mapping does not depend on filesystem listing order.
        self.class_to_imgs: dict[str, list[Path]] = {}
        for cls_dir in sorted(self.root.iterdir()):
            if not cls_dir.is_dir():
                continue
            imgs = sorted(
                p for p in cls_dir.iterdir()
                if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
            )
            if imgs:
                self.class_to_imgs[cls_dir.name] = imgs

        self.classes = list(self.class_to_imgs)
        if len(self.classes) < 2:
            raise ValueError(
                f"Need at least 2 classes with images under {self.root}, "
                f"found {len(self.classes)}."
            )
        # Classes able to form a positive pair from two *distinct* files.
        self._multi_img_classes = [
            cls for cls, imgs in self.class_to_imgs.items() if len(imgs) >= 2
        ]

        # Pre-generate pairs for this epoch
        self._pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Return a shuffled list of ``(path_a, path_b, label)`` tuples."""
        rng = self._rng if self._rng is not None else random
        n_same = int(self.pairs_per_epoch * self.same_ratio)
        n_diff = self.pairs_per_epoch - n_same
        # Prefer classes with ≥ 2 images so positives are two different files;
        # fall back to self-pairs (distinct only via augmentation) when every
        # class is a singleton.
        same_pool = self._multi_img_classes or self.classes

        pairs = []
        # Same-class pairs (label = 1)
        for _ in range(n_same):
            imgs = self.class_to_imgs[rng.choice(same_pool)]
            if len(imgs) >= 2:
                a, b = rng.sample(imgs, 2)
            else:
                a = b = imgs[0]
            pairs.append((a, b, 1))

        # Different-class pairs (label = 0)
        for _ in range(n_diff):
            c1, c2 = rng.sample(self.classes, 2)
            a = rng.choice(self.class_to_imgs[c1])
            b = rng.choice(self.class_to_imgs[c2])
            pairs.append((a, b, 0))

        rng.shuffle(pairs)
        return pairs

    def on_epoch_end(self):
        """Call between epochs to refresh random pairs.

        Note: a DataLoader created with ``persistent_workers=True`` keeps the
        workers' copy of the dataset, so the refresh would not reach them.
        """
        self._pairs = self._generate_pairs()

    def __len__(self):
        return len(self._pairs)

    def __getitem__(self, idx):
        path1, path2, label = self._pairs[idx]
        img1 = _load_rgb(path1)
        img2 = _load_rgb(path2)
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, torch.tensor(label, dtype=torch.float32)


# ─────────────────────────────────────────────────────────────
#  Demo: Synthetic dataset (circles / squares / triangles)
# ─────────────────────────────────────────────────────────────

class SyntheticShapeDataset(Dataset):
    """
    Generates synthetic image pairs on-the-fly.

    Classes: circle, square, triangle (3 classes). Each image is drawn fresh in
    __getitem__ with a random background, shape color and shape extent.
    Useful for quick testing without real data.

    Args:
        n_pairs: Number of pairs per (re)generation.
        img_size: Side length of the drawn square canvas.
        same_ratio: Probability that a pair is same-shape.
        augment: Use the training augmentation pipeline if True, otherwise the
            deterministic validation transform.
        seed: If given, the *pair list* is generated with a private RNG seeded
            with it; image drawing still uses the global ``random`` module.
    """

    SHAPES = ["circle", "square", "triangle"]

    def __init__(
        self,
        n_pairs: int = 2000,
        img_size: int = 224,
        same_ratio: float = 0.5,
        augment: bool = True,
        seed: Optional[int] = None,
    ):
        self.n_pairs = n_pairs
        self.img_size = img_size
        self.same_ratio = same_ratio
        self.augment = augment
        self.transform = get_train_transform(img_size) if augment else get_val_transform(img_size)
        self._rng = random.Random(seed) if seed is not None else None
        self._pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Return a list of ``(shape_a, shape_b, label)`` tuples."""
        rng = self._rng if self._rng is not None else random
        pairs = []
        for _ in range(self.n_pairs):
            if rng.random() < self.same_ratio:
                shape = rng.choice(self.SHAPES)
                pairs.append((shape, shape, 1))
            else:
                s1, s2 = rng.sample(self.SHAPES, 2)
                pairs.append((s1, s2, 0))
        return pairs

    def on_epoch_end(self):
        """Call between epochs to refresh random pairs (same API as SiamesePairDataset)."""
        self._pairs = self._generate_pairs()

    def _draw_shape(self, shape: str) -> Image.Image:
        """Draw one filled shape on a light, randomly tinted RGB canvas."""
        sz = self.img_size
        img = Image.new("RGB", (sz, sz), color=(
            random.randint(200, 255),
            random.randint(200, 255),
            random.randint(200, 255),
        ))
        draw = ImageDraw.Draw(img)
        margin = sz // 6
        # Top-left corner in the upper-left third, bottom-right in the lower-right third.
        x0 = random.randint(margin, sz // 3)
        y0 = random.randint(margin, sz // 3)
        x1 = random.randint(sz * 2 // 3, sz - margin)
        y1 = random.randint(sz * 2 // 3, sz - margin)
        color = (random.randint(0, 100), random.randint(0, 100), random.randint(50, 150))

        if shape == "circle":
            draw.ellipse([x0, y0, x1, y1], fill=color)
        elif shape == "square":
            draw.rectangle([x0, y0, x1, y1], fill=color)
        else:  # triangle
            draw.polygon([(sz // 2, y0), (x0, y1), (x1, y1)], fill=color)
        return img

    def __len__(self):
        return len(self._pairs)

    def __getitem__(self, idx):
        s1, s2, label = self._pairs[idx]
        img1 = self.transform(self._draw_shape(s1))
        img2 = self.transform(self._draw_shape(s2))
        return img1, img2, torch.tensor(label, dtype=torch.float32)


# ─────────────────────────────────────────────────────────────
#  DataLoader factory
# ─────────────────────────────────────────────────────────────

def build_dataloaders(
    train_root: Optional[str] = None,
    val_root: Optional[str] = None,
    img_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 4,
    pairs_per_epoch: int = 5000,
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """Build (train_loader, val_loader) for Siamese training.

    Uses folder datasets when ``train_root`` exists, otherwise falls back to
    the synthetic shapes demo. The validation set has ``pairs_per_epoch // 5``
    pairs (at least 1). If ``val_root`` is not given, validation pairs are
    drawn from ``train_root`` and therefore overlap the training images.

    Args:
        seed: Optional seed for pair generation. The train set uses ``seed``
            and the val set ``seed + 1``. None keeps the global-``random``
            behavior.
    """
    n_val_pairs = max(1, pairs_per_epoch // 5)
    train_seed = seed
    val_seed = None if seed is None else seed + 1

    if train_root and Path(train_root).exists():
        if not val_root:
            print("⚠️ No val_root provided — validation pairs are drawn from train_root.")
        train_ds = SiamesePairDataset(
            train_root,
            transform=get_train_transform(img_size),
            pairs_per_epoch=pairs_per_epoch,
            seed=train_seed,
        )
        val_ds = SiamesePairDataset(
            val_root or train_root,
            transform=get_val_transform(img_size),
            pairs_per_epoch=n_val_pairs,
            seed=val_seed,
        )
    else:
        if train_root:
            print(f"⚠️ Dataset root '{train_root}' not found — using synthetic shapes for demo.")
        else:
            print("⚠️ No dataset root provided — using synthetic shapes for demo.")
        train_ds = SyntheticShapeDataset(
            n_pairs=pairs_per_epoch, img_size=img_size, augment=True, seed=train_seed,
        )
        val_ds = SyntheticShapeDataset(
            n_pairs=n_val_pairs, img_size=img_size, augment=False, seed=val_seed,
        )

    # Pinned host memory only helps (and only avoids a warning) with CUDA.
    pin_memory = torch.cuda.is_available()
    # persistent_workers is intentionally left off: on_epoch_end() refreshes
    # pairs in the main process, and persistent workers would keep a stale copy.
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    return train_loader, val_loader