listing-radar/model_export/dataset.py

276 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
Dataset & Augmentation Pipeline
================================
Handles:
• Random pairing (same-class → label=1, diff-class → label=0)
• Heavy augmentation for noise robustness:
- Random crops + resize → simulates cropping
- Random resize → simulates resolution differences
- Gaussian blur + noise → simulates sensor noise
- Color jitter → lighting variation
- Horizontal flip, rotation
"""
import os
import random
from pathlib import Path
from typing import Optional, Tuple, Callable
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFilter, ImageEnhance
import numpy as np
# ─────────────────────────────────────────────────────────────
# Custom Transform: Add Gaussian pixel noise
# ─────────────────────────────────────────────────────────────
class AddGaussianNoise:
    """Additively perturb a [0, 1] image tensor with zero-mean Gaussian noise.

    Applied after ``ToTensor`` to simulate sensor noise; the result is
    clamped back into the valid [0, 1] pixel range.
    """

    def __init__(self, std: float = 0.05):
        # Standard deviation of the per-pixel noise.
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        perturbed = tensor + self.std * torch.randn_like(tensor)
        return perturbed.clamp(0.0, 1.0)
# ─────────────────────────────────────────────────────────────
# Augmentation Pipelines
# ─────────────────────────────────────────────────────────────
# ImageNet channel statistics — the normalization expected by torchvision
# backbones pretrained on ImageNet; used by both train and val pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_train_transform(img_size: int = 224) -> transforms.Compose:
    """Heavy augmentation pipeline for training — robust to crops, resolution, noise.

    Note: the original code passed ``Resize`` an argument of the form
    ``(random tuple) if False else img_size`` — a dead placeholder that always
    collapsed to ``Resize(img_size)`` (and, being inside ``Compose``, would
    only have been evaluated once at construction time anyway, not per
    sample). The dead branch is removed; per-sample scale/resolution
    variation is already provided by ``RandomResizedCrop``.

    Args:
        img_size: output side length of the square augmented image.

    Returns:
        A ``transforms.Compose`` producing a normalized float tensor.
    """
    return transforms.Compose([
        # Normalize input scale (shorter edge → img_size) before cropping.
        transforms.Resize(img_size),
        transforms.RandomResizedCrop(
            img_size,
            scale=(0.5, 1.0),   # aggressive cropping (50–100 % of image)
            ratio=(0.75, 1.33),
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        # Lighting variation.
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1,
        ),
        transforms.RandomGrayscale(p=0.1),
        # Simulate out-of-focus / sensor blur on ~30 % of samples.
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3
        ),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.03),  # pixel-level noise, after ToTensor
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
def get_val_transform(img_size: int = 224) -> transforms.Compose:
    """Deterministic preprocessing for validation / inference.

    Plain square resize → tensor → ImageNet normalization; no randomness,
    so evaluation results are reproducible.
    """
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
# ─────────────────────────────────────────────────────────────
# Siamese Pair Dataset
# ─────────────────────────────────────────────────────────────
class SiamesePairDataset(Dataset):
    """Random-pair dataset for Siamese training.

    Folder layout expected:
        root/
            class_A/ img1.jpg img2.jpg ...
            class_B/ img1.jpg ...

    Each __getitem__ returns (img1_tensor, img2_tensor, label):
        label = 1 → same class
        label = 0 → different class
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pairs_per_epoch: int = 5000,
        same_ratio: float = 0.5,
    ):
        """
        Args:
            root: directory containing one sub-folder per class.
            transform: callable applied to each PIL image (e.g. a torchvision
                pipeline); images are returned untransformed if ``None``.
            pairs_per_epoch: number of pairs sampled per epoch.
            same_ratio: fraction of pairs drawn from a single class.

        Raises:
            ValueError: if fewer than two class folders contain images.
        """
        self.root = Path(root)
        self.transform = transform
        self.pairs_per_epoch = pairs_per_epoch
        self.same_ratio = same_ratio
        # Build class → [image paths] mapping (only folders that hold images).
        self.class_to_imgs: dict[str, list[Path]] = {}
        for cls_dir in sorted(self.root.iterdir()):
            if cls_dir.is_dir():
                imgs = [
                    p for p in cls_dir.iterdir()
                    if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
                ]
                if imgs:
                    self.class_to_imgs[cls_dir.name] = imgs
        self.classes = list(self.class_to_imgs.keys())
        # Explicit validation instead of `assert` — asserts vanish under -O.
        if len(self.classes) < 2:
            raise ValueError("Need at least 2 classes.")
        # Pre-generate pairs for this epoch.
        self._pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Sample a fresh shuffled list of (path_a, path_b, label) triples."""
        pairs = []
        n_same = int(self.pairs_per_epoch * self.same_ratio)
        n_diff = self.pairs_per_epoch - n_same
        # Same-class pairs (label = 1)
        for _ in range(n_same):
            cls = random.choice(self.classes)
            imgs = self.class_to_imgs[cls]
            if len(imgs) < 2:
                # Degenerate class with one image: pair it with itself.
                a = b = imgs[0]
            else:
                a, b = random.sample(imgs, 2)
            pairs.append((a, b, 1))
        # Different-class pairs (label = 0)
        for _ in range(n_diff):
            c1, c2 = random.sample(self.classes, 2)
            a = random.choice(self.class_to_imgs[c1])
            b = random.choice(self.class_to_imgs[c2])
            pairs.append((a, b, 0))
        random.shuffle(pairs)
        return pairs

    def on_epoch_end(self):
        """Call between epochs to refresh random pairs."""
        self._pairs = self._generate_pairs()

    def __len__(self):
        return len(self._pairs)

    def __getitem__(self, idx):
        path1, path2, label = self._pairs[idx]
        # Context managers release file handles promptly — PIL loads lazily,
        # and leaked descriptors add up with many DataLoader workers.
        with Image.open(path1) as im1:
            img1 = im1.convert("RGB")
        with Image.open(path2) as im2:
            img2 = im2.convert("RGB")
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, torch.tensor(label, dtype=torch.float32)
# ─────────────────────────────────────────────────────────────
# Demo: Synthetic dataset (MNIST-style circles/squares)
# ─────────────────────────────────────────────────────────────
class SyntheticShapeDataset(Dataset):
    """
    On-the-fly synthetic pair generator (circle / square / triangle).

    Nothing is read from disk — every item renders two fresh shape images,
    which makes this convenient for smoke-testing the training loop
    without any real data.
    """

    SHAPES = ["circle", "square", "triangle"]

    def __init__(
        self,
        n_pairs: int = 2000,
        img_size: int = 224,
        same_ratio: float = 0.5,
        augment: bool = True,
    ):
        self.n_pairs = n_pairs
        self.img_size = img_size
        self.same_ratio = same_ratio
        self.augment = augment
        if augment:
            self.transform = get_train_transform(img_size)
        else:
            self.transform = get_val_transform(img_size)
        self._pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Pre-draw (shape_a, shape_b, label) triples for the whole epoch."""
        result = []
        for _ in range(self.n_pairs):
            if random.random() < self.same_ratio:
                shape = random.choice(self.SHAPES)
                result.append((shape, shape, 1))
            else:
                first, second = random.sample(self.SHAPES, 2)
                result.append((first, second, 0))
        return result

    def _draw_shape(self, shape: str) -> Image.Image:
        """Render one randomly-placed, randomly-colored shape on a light background."""
        from PIL import ImageDraw
        side = self.img_size
        background = (
            random.randint(200, 255),
            random.randint(200, 255),
            random.randint(200, 255),
        )
        canvas = Image.new("RGB", (side, side), color=background)
        pen = ImageDraw.Draw(canvas)
        pad = side // 6
        # Bounding box: top-left in the first third, bottom-right in the last.
        left = random.randint(pad, side // 3)
        top = random.randint(pad, side // 3)
        right = random.randint(side * 2 // 3, side - pad)
        bottom = random.randint(side * 2 // 3, side - pad)
        fill = (random.randint(0, 100), random.randint(0, 100), random.randint(50, 150))
        if shape == "circle":
            pen.ellipse([left, top, right, bottom], fill=fill)
        elif shape == "square":
            pen.rectangle([left, top, right, bottom], fill=fill)
        else:  # triangle
            pen.polygon([(side // 2, top), (left, bottom), (right, bottom)], fill=fill)
        return canvas

    def __len__(self):
        return self.n_pairs

    def __getitem__(self, idx):
        shape_a, shape_b, label = self._pairs[idx]
        tensor_a = self.transform(self._draw_shape(shape_a))
        tensor_b = self.transform(self._draw_shape(shape_b))
        return tensor_a, tensor_b, torch.tensor(label, dtype=torch.float32)
# ─────────────────────────────────────────────────────────────
# DataLoader factory
# ─────────────────────────────────────────────────────────────
def build_dataloaders(
    train_root: Optional[str] = None,
    val_root: Optional[str] = None,
    img_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 4,
    pairs_per_epoch: int = 5000,
) -> Tuple[DataLoader, DataLoader]:
    """Construct (train_loader, val_loader).

    Real folder datasets are used when ``train_root`` exists; otherwise a
    synthetic shape dataset is substituted so the pipeline still runs.
    Validation falls back to ``train_root`` when no ``val_root`` is given,
    and always uses the deterministic transform.
    """
    have_real_data = bool(train_root) and Path(train_root).exists()
    if have_real_data:
        train_ds = SiamesePairDataset(
            train_root,
            transform=get_train_transform(img_size),
            pairs_per_epoch=pairs_per_epoch,
        )
        val_ds = SiamesePairDataset(
            val_root or train_root,
            transform=get_val_transform(img_size),
            pairs_per_epoch=pairs_per_epoch // 5,
        )
    else:
        print("⚠️ No dataset root provided — using synthetic shapes for demo.")
        train_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch, img_size=img_size, augment=True)
        val_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch // 5, img_size=img_size, augment=False)
    # Shared loader settings; only shuffling differs between splits.
    loader_kwargs = dict(num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return train_loader, val_loader