"""
Dataset & Augmentation Pipeline
================================

Handles:
• Random pairing (same-class → label=1, diff-class → label=0)
• Heavy augmentation for noise robustness:
    - Random crops + resize → simulates cropping
    - Random resize → simulates resolution differences
    - Gaussian blur + noise → simulates sensor noise
    - Color jitter → lighting variation
    - Horizontal flip, rotation
"""
|
||
|
||
import os
|
||
import random
|
||
from pathlib import Path
|
||
from typing import Optional, Tuple, Callable
|
||
|
||
import torch
|
||
from torch.utils.data import Dataset, DataLoader
|
||
from torchvision import transforms
|
||
from PIL import Image, ImageFilter, ImageEnhance
|
||
import numpy as np
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────
# Custom Transform: Add Gaussian pixel noise
# ─────────────────────────────────────────────────────────────
class AddGaussianNoise:
    """Additively perturb a [0, 1] image tensor with zero-mean Gaussian noise.

    Intended to sit after ToTensor() in a torchvision pipeline; the output is
    clamped back into [0, 1] so downstream normalization sees valid pixels.
    """

    def __init__(self, std: float = 0.05):
        # Standard deviation of the injected noise (pixel values are in [0, 1]).
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        perturbed = tensor + torch.randn_like(tensor) * self.std
        return perturbed.clamp(0.0, 1.0)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────
# Augmentation Pipelines
# ─────────────────────────────────────────────────────────────
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def get_train_transform(img_size: int = 224) -> transforms.Compose:
    """Heavy augmentation for training — robust to crops, resolution, noise.

    Args:
        img_size: side length of the square output image.

    Returns:
        A ``transforms.Compose`` producing a normalized (C, H, W) float tensor.
    """
    return transforms.Compose([
        # BUG FIX: the old pipeline contained
        #   Resize((uniform, uniform)) if False else img_size
        # — a dead placeholder whose random.uniform() would, even if enabled,
        # have been sampled once at construction time rather than per image.
        # RandomResizedCrop below already supplies per-sample scale variation
        # (simulating resolution differences), so a deterministic resize is
        # correct here. Runtime behavior is unchanged (the conditional always
        # evaluated to img_size).
        transforms.Resize(img_size),
        transforms.RandomResizedCrop(
            img_size,
            scale=(0.5, 1.0),  # aggressive cropping (50 – 100 % of image)
            ratio=(0.75, 1.33),
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1,
        ),
        transforms.RandomGrayscale(p=0.1),
        # Blur ~30 % of samples to mimic sensor/defocus noise.
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.03),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
|
||
|
||
|
||
def get_val_transform(img_size: int = 224) -> transforms.Compose:
    """Deterministic resize → tensor → normalize pipeline for validation/inference."""
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────
|
||
# Siamese Pair Dataset
|
||
# ─────────────────────────────────────────────────────────────
|
||
class SiamesePairDataset(Dataset):
|
||
"""
|
||
Folder layout expected:
|
||
root/
|
||
class_A/ img1.jpg img2.jpg ...
|
||
class_B/ img1.jpg ...
|
||
|
||
Each __getitem__ returns (img1_tensor, img2_tensor, label)
|
||
label = 1 → same class
|
||
label = 0 → different class
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
root: str,
|
||
transform: Optional[Callable] = None,
|
||
pairs_per_epoch: int = 5000,
|
||
same_ratio: float = 0.5,
|
||
):
|
||
self.root = Path(root)
|
||
self.transform = transform
|
||
self.pairs_per_epoch = pairs_per_epoch
|
||
self.same_ratio = same_ratio
|
||
|
||
# Build class → [image paths] mapping
|
||
self.class_to_imgs: dict[str, list[Path]] = {}
|
||
for cls_dir in sorted(self.root.iterdir()):
|
||
if cls_dir.is_dir():
|
||
imgs = [
|
||
p for p in cls_dir.iterdir()
|
||
if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
|
||
]
|
||
if imgs:
|
||
self.class_to_imgs[cls_dir.name] = imgs
|
||
|
||
self.classes = list(self.class_to_imgs.keys())
|
||
assert len(self.classes) >= 2, "Need at least 2 classes."
|
||
|
||
# Pre-generate pairs for this epoch
|
||
self._pairs = self._generate_pairs()
|
||
|
||
def _generate_pairs(self):
|
||
pairs = []
|
||
n_same = int(self.pairs_per_epoch * self.same_ratio)
|
||
n_diff = self.pairs_per_epoch - n_same
|
||
|
||
# Same-class pairs (label = 1)
|
||
for _ in range(n_same):
|
||
cls = random.choice(self.classes)
|
||
imgs = self.class_to_imgs[cls]
|
||
if len(imgs) < 2:
|
||
a = b = imgs[0]
|
||
else:
|
||
a, b = random.sample(imgs, 2)
|
||
pairs.append((a, b, 1))
|
||
|
||
# Different-class pairs (label = 0)
|
||
for _ in range(n_diff):
|
||
c1, c2 = random.sample(self.classes, 2)
|
||
a = random.choice(self.class_to_imgs[c1])
|
||
b = random.choice(self.class_to_imgs[c2])
|
||
pairs.append((a, b, 0))
|
||
|
||
random.shuffle(pairs)
|
||
return pairs
|
||
|
||
def on_epoch_end(self):
|
||
"""Call between epochs to refresh random pairs."""
|
||
self._pairs = self._generate_pairs()
|
||
|
||
def __len__(self):
|
||
return len(self._pairs)
|
||
|
||
def __getitem__(self, idx):
|
||
path1, path2, label = self._pairs[idx]
|
||
img1 = Image.open(path1).convert("RGB")
|
||
img2 = Image.open(path2).convert("RGB")
|
||
|
||
if self.transform:
|
||
img1 = self.transform(img1)
|
||
img2 = self.transform(img2)
|
||
|
||
return img1, img2, torch.tensor(label, dtype=torch.float32)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────
# Demo: Synthetic dataset (MNIST-style circles/squares)
# ─────────────────────────────────────────────────────────────
class SyntheticShapeDataset(Dataset):
    """
    On-the-fly synthetic pair generator over three shape classes
    (circle / square / triangle). Handy for smoke-testing the training
    loop without any real data on disk.
    """

    SHAPES = ["circle", "square", "triangle"]

    def __init__(
        self,
        n_pairs: int = 2000,
        img_size: int = 224,
        same_ratio: float = 0.5,
        augment: bool = True,
    ):
        self.n_pairs = n_pairs
        self.img_size = img_size
        self.same_ratio = same_ratio
        self.augment = augment
        # Training-style augmentation, or the deterministic eval transform.
        if augment:
            self.transform = get_train_transform(img_size)
        else:
            self.transform = get_val_transform(img_size)
        self._pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Pre-draw (shape1, shape2, label) triples for the whole epoch."""
        out = []
        for _ in range(self.n_pairs):
            if random.random() < self.same_ratio:
                shape = random.choice(self.SHAPES)
                out.append((shape, shape, 1))
            else:
                first, second = random.sample(self.SHAPES, 2)
                out.append((first, second, 0))
        return out

    def _draw_shape(self, shape: str) -> "Image.Image":
        """Render one randomly-placed, randomly-colored shape on a light background."""
        from PIL import ImageDraw

        side = self.img_size
        background = tuple(random.randint(200, 255) for _ in range(3))
        canvas = Image.new("RGB", (side, side), color=background)
        draw = ImageDraw.Draw(canvas)

        pad = side // 6
        x0 = random.randint(pad, side // 3)
        y0 = random.randint(pad, side // 3)
        x1 = random.randint(side * 2 // 3, side - pad)
        y1 = random.randint(side * 2 // 3, side - pad)
        fill = (random.randint(0, 100), random.randint(0, 100), random.randint(50, 150))

        if shape == "circle":
            draw.ellipse([x0, y0, x1, y1], fill=fill)
        elif shape == "square":
            draw.rectangle([x0, y0, x1, y1], fill=fill)
        else:  # triangle
            draw.polygon([(side // 2, y0), (x0, y1), (x1, y1)], fill=fill)
        return canvas

    def __len__(self):
        return self.n_pairs

    def __getitem__(self, idx):
        shape1, shape2, label = self._pairs[idx]
        first = self.transform(self._draw_shape(shape1))
        second = self.transform(self._draw_shape(shape2))
        return first, second, torch.tensor(label, dtype=torch.float32)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────
# DataLoader factory
# ─────────────────────────────────────────────────────────────
def build_dataloaders(
    train_root: Optional[str] = None,
    val_root: Optional[str] = None,
    img_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 4,
    pairs_per_epoch: int = 5000,
) -> Tuple[DataLoader, DataLoader]:
    """Build (train_loader, val_loader) from folder datasets or synthetic shapes.

    Args:
        train_root:      class-per-folder image root; if None/missing, a
                         synthetic shape dataset is used instead.
        val_root:        optional separate validation root; falls back to
                         train_root when absent.
        img_size:        square side length fed to the transforms.
        batch_size:      batch size for both loaders.
        num_workers:     DataLoader worker processes.
        pairs_per_epoch: training pairs per epoch (validation uses 1/5 of it).

    Returns:
        Tuple of (train_loader, val_loader).
    """
    if train_root and Path(train_root).exists():
        train_ds = SiamesePairDataset(
            train_root,
            transform=get_train_transform(img_size),
            pairs_per_epoch=pairs_per_epoch,
        )
        # BUG FIX: a provided-but-nonexistent val_root previously crashed later
        # inside Path.iterdir() with an opaque FileNotFoundError. Warn and fall
        # back to train_root instead.
        if val_root and not Path(val_root).exists():
            print(f"⚠️ val_root '{val_root}' not found — validating on train_root instead.")
            val_root = None
        val_ds = SiamesePairDataset(
            val_root or train_root,
            transform=get_val_transform(img_size),
            pairs_per_epoch=pairs_per_epoch // 5,
        )
    else:
        print("⚠️ No dataset root provided — using synthetic shapes for demo.")
        train_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch, img_size=img_size, augment=True)
        val_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch // 5, img_size=img_size, augment=False)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True,
    )
    return train_loader, val_loader
|