diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..3da93a3
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,275 @@
+"""
+Dataset & Augmentation Pipeline
+================================
+Handles:
+  • Random pairing (same-class → label=1, diff-class → label=0)
+  • Heavy augmentation for noise robustness:
+      - Random resized crops  → simulates cropping and resolution differences
+      - Gaussian blur + noise → simulates sensor noise
+      - Color jitter          → lighting variation
+      - Horizontal flip, rotation
+"""
+
+import random
+from pathlib import Path
+from typing import Optional, Tuple, Callable
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+
+
+# ─────────────────────────────────────────────────────────────
+# Custom Transform: Add Gaussian pixel noise
+# ─────────────────────────────────────────────────────────────
+class AddGaussianNoise:
+    def __init__(self, std: float = 0.05):
+        self.std = std
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        noise = torch.randn_like(tensor) * self.std
+        return torch.clamp(tensor + noise, 0.0, 1.0)
+
+
+# ─────────────────────────────────────────────────────────────
+# Augmentation Pipelines
+# ─────────────────────────────────────────────────────────────
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+def get_train_transform(img_size: int = 224) -> transforms.Compose:
+    """Heavy augmentation for training — robust to crops, resolution, noise."""
+    return transforms.Compose([
+        # RandomResizedCrop covers both cropping and resolution variation:
+        # it samples a sub-region (50–100 % of the area) at a random aspect
+        # ratio and rescales it to img_size.
+        transforms.RandomResizedCrop(
+            img_size,
+            scale=(0.5, 1.0),
+            ratio=(0.75, 1.33),
+        ),
+        transforms.RandomHorizontalFlip(p=0.5),
+        transforms.RandomRotation(degrees=15),
+        transforms.ColorJitter(
+            brightness=0.4,
+            contrast=0.4,
+            saturation=0.4,
+            hue=0.1,
+        ),
+        transforms.RandomGrayscale(p=0.1),
+        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3),
+        transforms.ToTensor(),
+        AddGaussianNoise(std=0.03),
+        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+    ])
+
+
+def get_val_transform(img_size: int = 224) -> transforms.Compose:
+    """Deterministic transform for validation/inference."""
+    return transforms.Compose([
+        transforms.Resize((img_size, img_size)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+    ])
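+
+
+# Example (illustrative): both pipelines map a PIL image to a normalized
+# (3, img_size, img_size) float tensor; only the train pipeline is stochastic,
+# so two calls on the same image yield two different "views":
+#
+#     t = get_train_transform(224)
+#     x1, x2 = t(pil_img), t(pil_img)   # each torch.Size([3, 224, 224]), x1 != x2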
+
+
+# ─────────────────────────────────────────────────────────────
+# Siamese Pair Dataset
+# ─────────────────────────────────────────────────────────────
+class SiamesePairDataset(Dataset):
+    """
+    Folder layout expected:
+        root/
+            class_A/  img1.jpg  img2.jpg ...
+            class_B/  img1.jpg ...
+
+    Each __getitem__ returns (img1_tensor, img2_tensor, label)
+        label = 1 → same class
+        label = 0 → different class
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        pairs_per_epoch: int = 5000,
+        same_ratio: float = 0.5,
+    ):
+        self.root = Path(root)
+        self.transform = transform
+        self.pairs_per_epoch = pairs_per_epoch
+        self.same_ratio = same_ratio
+
+        # Build class → [image paths] mapping
+        self.class_to_imgs: dict[str, list[Path]] = {}
+        for cls_dir in sorted(self.root.iterdir()):
+            if cls_dir.is_dir():
+                imgs = [
+                    p for p in cls_dir.iterdir()
+                    if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
+                ]
+                if imgs:
+                    self.class_to_imgs[cls_dir.name] = imgs
+
+        self.classes = list(self.class_to_imgs.keys())
+        assert len(self.classes) >= 2, "Need at least 2 classes."
+
+        # Pre-generate pairs for this epoch
+        self._pairs = self._generate_pairs()
+
+    def _generate_pairs(self):
+        pairs = []
+        n_same = int(self.pairs_per_epoch * self.same_ratio)
+        n_diff = self.pairs_per_epoch - n_same
+
+        # Same-class pairs (label = 1)
+        for _ in range(n_same):
+            cls = random.choice(self.classes)
+            imgs = self.class_to_imgs[cls]
+            if len(imgs) < 2:
+                # Single-image class: degenerate pair of the image with itself;
+                # random augmentation still produces two distinct views.
+                a = b = imgs[0]
+            else:
+                a, b = random.sample(imgs, 2)
+            pairs.append((a, b, 1))
+
+        # Different-class pairs (label = 0)
+        for _ in range(n_diff):
+            c1, c2 = random.sample(self.classes, 2)
+            a = random.choice(self.class_to_imgs[c1])
+            b = random.choice(self.class_to_imgs[c2])
+            pairs.append((a, b, 0))
+
+        random.shuffle(pairs)
+        return pairs
+
+    def on_epoch_end(self):
+        """Call between epochs to refresh random pairs."""
+        self._pairs = self._generate_pairs()
+
+    def __len__(self):
+        return len(self._pairs)
+
+    def __getitem__(self, idx):
+        path1, path2, label = self._pairs[idx]
+        img1 = Image.open(path1).convert("RGB")
+        img2 = Image.open(path2).convert("RGB")
+
+        if self.transform:
+            img1 = self.transform(img1)
+            img2 = self.transform(img2)
+
+        return img1, img2, torch.tensor(label, dtype=torch.float32)
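+
+# Usage (illustrative):
+#     ds = SiamesePairDataset("data/train", transform=get_train_transform(224),
+#                             pairs_per_epoch=1000)
+#     img1, img2, label = ds[0]   # label: 1.0 = same class, 0.0 = different class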
+ """ + + SHAPES = ["circle", "square", "triangle"] + + def __init__( + self, + n_pairs: int = 2000, + img_size: int = 224, + same_ratio: float = 0.5, + augment: bool = True, + ): + self.n_pairs = n_pairs + self.img_size = img_size + self.same_ratio = same_ratio + self.augment = augment + self.transform = get_train_transform(img_size) if augment else get_val_transform(img_size) + self._pairs = self._generate_pairs() + + def _generate_pairs(self): + pairs = [] + for _ in range(self.n_pairs): + if random.random() < self.same_ratio: + s = random.choice(self.SHAPES) + pairs.append((s, s, 1)) + else: + s1, s2 = random.sample(self.SHAPES, 2) + pairs.append((s1, s2, 0)) + return pairs + + def _draw_shape(self, shape: str) -> Image.Image: + from PIL import ImageDraw + sz = self.img_size + img = Image.new("RGB", (sz, sz), color=( + random.randint(200, 255), + random.randint(200, 255), + random.randint(200, 255), + )) + draw = ImageDraw.Draw(img) + margin = sz // 6 + x0 = random.randint(margin, sz // 3) + y0 = random.randint(margin, sz // 3) + x1 = random.randint(sz * 2 // 3, sz - margin) + y1 = random.randint(sz * 2 // 3, sz - margin) + color = (random.randint(0, 100), random.randint(0, 100), random.randint(50, 150)) + if shape == "circle": + draw.ellipse([x0, y0, x1, y1], fill=color) + elif shape == "square": + draw.rectangle([x0, y0, x1, y1], fill=color) + else: # triangle + draw.polygon([(sz // 2, y0), (x0, y1), (x1, y1)], fill=color) + return img + + def __len__(self): + return self.n_pairs + + def __getitem__(self, idx): + s1, s2, label = self._pairs[idx] + img1 = self.transform(self._draw_shape(s1)) + img2 = self.transform(self._draw_shape(s2)) + return img1, img2, torch.tensor(label, dtype=torch.float32) + + +# ───────────────────────────────────────────────────────────── +# DataLoader factory +# ───────────────────────────────────────────────────────────── +def build_dataloaders( + train_root: Optional[str] = None, + val_root: Optional[str] = None, + img_size: int = 224, + batch_size: int = 32, + num_workers: int = 4, + pairs_per_epoch: int = 5000, +) -> Tuple[DataLoader, DataLoader]: + + if train_root and Path(train_root).exists(): + train_ds = SiamesePairDataset( + train_root, + transform=get_train_transform(img_size), + pairs_per_epoch=pairs_per_epoch, + ) + val_ds = SiamesePairDataset( + val_root or train_root, + transform=get_val_transform(img_size), + pairs_per_epoch=pairs_per_epoch // 5, + ) + else: + print("⚠️ No dataset root provided — using synthetic shapes for demo.") + train_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch, img_size=img_size, augment=True) + val_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch // 5, img_size=img_size, augment=False) + + train_loader = DataLoader( + train_ds, batch_size=batch_size, shuffle=True, + num_workers=num_workers, pin_memory=True, + ) + val_loader = DataLoader( + val_ds, batch_size=batch_size, shuffle=False, + num_workers=num_workers, pin_memory=True, + ) + return train_loader, val_loader diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..565880d --- /dev/null +++ b/inference.py @@ -0,0 +1,202 @@ +""" +Inference Module +================ +Load a trained Siamese Network and compare images. 
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..565880d
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,202 @@
+"""
+Inference Module
+================
+Load a trained Siamese Network and compare images.
+
+Usage:
+    python inference.py --img1 a.jpg --img2 b.jpg --checkpoint checkpoints/best_model.pth
+    python inference.py --img1 logo.png --gallery_dir /images/   # find best match
+"""
+
+import argparse
+from pathlib import Path
+from typing import List, Tuple, Optional
+
+import torch
+import torch.nn.functional as F
+from PIL import Image
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+import numpy as np
+
+from model import SiameseNet
+from dataset import get_val_transform
+
+
+# ─────────────────────────────────────────────────────────────
+# Predictor class
+# ─────────────────────────────────────────────────────────────
+class SiamesePredictor:
+    """
+    High-level inference wrapper.
+    """
+
+    def __init__(
+        self,
+        checkpoint_path: Optional[str] = None,
+        embedding_dim: int = 128,
+        device: Optional[str] = None,
+        threshold: float = 0.6,
+    ):
+        self.device = torch.device(device or (
+            "cuda" if torch.cuda.is_available() else
+            "mps" if torch.backends.mps.is_available() else
+            "cpu"
+        ))
+        self.threshold = threshold
+        self.transform = get_val_transform(224)
+
+        # Build model
+        self.model = SiameseNet(embedding_dim=embedding_dim, pretrained=False)
+        self.model.to(self.device)
+
+        if checkpoint_path and Path(checkpoint_path).exists():
+            ckpt = torch.load(checkpoint_path, map_location=self.device)
+            self.model.load_state_dict(ckpt["model_state"])
+            print(f"✅ Loaded checkpoint from {checkpoint_path}")
+            print(f"   Trained for {ckpt.get('epoch', '?')} epochs | "
+                  f"val_acc = {ckpt.get('val_acc', 0):.3f}")
+        else:
+            print("⚠️ No checkpoint — running with random weights (demo only).")
+
+        self.model.eval()
+
+    def _load(self, path_or_pil) -> torch.Tensor:
+        if isinstance(path_or_pil, (str, Path)):
+            img = Image.open(path_or_pil).convert("RGB")
+        else:
+            img = path_or_pil.convert("RGB")
+        return self.transform(img).unsqueeze(0).to(self.device)
+
+    @torch.no_grad()
+    def compare(self, img1, img2) -> dict:
+        """
+        Compare two images.
+        Returns dict with score, match, and euclidean distance.
+        """
+        t1 = self._load(img1)
+        t2 = self._load(img2)
+        score, emb1, emb2 = self.model(t1, t2)
+        s = score.item()
+        dist = F.pairwise_distance(emb1, emb2).item()
+        return {
+            "similarity_score": round(s, 4),
+            "euclidean_dist": round(dist, 4),
+            "match": s >= self.threshold,
+            "confidence": "HIGH" if abs(s - 0.5) > 0.3 else
+                          "MED" if abs(s - 0.5) > 0.15 else "LOW",
+        }
+
+    @torch.no_grad()
+    def embed(self, img) -> np.ndarray:
+        """Get embedding vector for an image."""
+        t = self._load(img)
+        return self.model.get_embedding(t).cpu().numpy()[0]
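+
+    # Usage (illustrative):
+    #     pred = SiamesePredictor("checkpoints/best_model.pth", threshold=0.6)
+    #     res = pred.compare("a.jpg", "b.jpg")
+    #     print(res["similarity_score"], res["match"])   # values depend on the checkpoint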
+ """ + q_emb = torch.from_numpy(self.embed(query_img)).to(self.device) + scores = [] + for idx, img in enumerate(gallery): + g_emb = torch.from_numpy(self.embed(img)).to(self.device) + cos = F.cosine_similarity(q_emb.unsqueeze(0), g_emb.unsqueeze(0)).item() + s = (cos + 1.0) / 2.0 + scores.append((s, idx)) + scores.sort(reverse=True) + return scores[:top_k] + + +# ───────────────────────────────────────────────────────────── +# Visualisation helper +# ───────────────────────────────────────────────────────────── +def visualise_comparison(img1_path, img2_path, result: dict, save_path="comparison.png"): + img1 = Image.open(img1_path).convert("RGB") + img2 = Image.open(img2_path).convert("RGB") + + fig, axes = plt.subplots(1, 3, figsize=(14, 5)) + fig.patch.set_facecolor("#0f0f1a") + + for ax in axes: + ax.set_facecolor("#0f0f1a") + ax.axis("off") + + axes[0].imshow(img1) + axes[0].set_title("Image 1", color="white", fontsize=13, pad=10) + + axes[2].imshow(img2) + axes[2].set_title("Image 2", color="white", fontsize=13, pad=10) + + # Middle panel: result + color = "#00e676" if result["match"] else "#ff1744" + symbol = "✓ MATCH" if result["match"] else "✗ NO MATCH" + axes[1].set_xlim(0, 1); axes[1].set_ylim(0, 1) + axes[1].add_patch(mpatches.FancyBboxPatch( + (0.1, 0.2), 0.8, 0.6, boxstyle="round,pad=0.05", + linewidth=3, edgecolor=color, facecolor="#1a1a2e" + )) + axes[1].text(0.5, 0.75, symbol, ha="center", va="center", fontsize=20, + fontweight="bold", color=color) + axes[1].text(0.5, 0.55, f"Score: {result['similarity_score']:.3f}", + ha="center", va="center", fontsize=14, color="white") + axes[1].text(0.5, 0.40, f"Dist: {result['euclidean_dist']:.3f}", + ha="center", va="center", fontsize=11, color="#aaaaaa") + axes[1].text(0.5, 0.28, f"Conf: {result['confidence']}", + ha="center", va="center", fontsize=11, color="#aaaaaa") + + plt.tight_layout(pad=2) + plt.savefig(save_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"📸 Visualisation saved → {save_path}") + + +# ───────────────────────────────────────────────────────────── +# CLI +# ───────────────────────────────────────────────────────────── +if __name__ == "__main__": + p = argparse.ArgumentParser(description="Siamese Network Inference") + p.add_argument("--img1", required=True, help="Path to image 1") + p.add_argument("--img2", default=None, help="Path to image 2 (comparison mode)") + p.add_argument("--gallery_dir", default=None, help="Directory to search (gallery mode)") + p.add_argument("--checkpoint", default="checkpoints/best_model.pth") + p.add_argument("--threshold", type=float, default=0.6, help="Match threshold") + p.add_argument("--top_k", type=int, default=5) + p.add_argument("--output", default="comparison.png") + args = p.parse_args() + + predictor = SiamesePredictor( + checkpoint_path=args.checkpoint, + threshold=args.threshold, + ) + + if args.img2: + # ── Compare two images ──────────────────────────────── + result = predictor.compare(args.img1, args.img2) + print("\n" + "═" * 45) + print(f" Similarity score : {result['similarity_score']}") + print(f" Euclidean dist : {result['euclidean_dist']}") + print(f" Match : {'YES ✓' if result['match'] else 'NO ✗'}") + print(f" Confidence : {result['confidence']}") + print("═" * 45) + visualise_comparison(args.img1, args.img2, result, args.output) + + elif args.gallery_dir: + # ── Gallery search ──────────────────────────────────── + exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"} + gallery = [p for p in Path(args.gallery_dir).iterdir() if p.suffix.lower() in exts] 
+        matches = predictor.find_best_match(args.img1, gallery, top_k=args.top_k)
+        print(f"\nTop-{args.top_k} matches for {args.img1}:")
+        for rank, (score, idx) in enumerate(matches, 1):
+            verdict = "✓ MATCH" if score >= args.threshold else " diff  "
+            print(f" {rank}. {verdict} score={score:.4f}  {gallery[idx]}")
+    else:
+        print("Provide --img2 or --gallery_dir")
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d59de4c
--- /dev/null
+++ b/model.py
@@ -0,0 +1,113 @@
+"""
+Siamese Neural Network for Image Similarity Matching
+======================================================
+Robust to: cropping, resolution differences, noise, rotation, color jitter.
+
+Architecture:
+  - Shared CNN encoder (ResNet-18 backbone, pretrained)
+  - L2-normalized embedding head
+  - Contrastive Loss / Binary Cross-Entropy on cosine similarity
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+
+# ─────────────────────────────────────────────────────────────
+# Embedding Network (shared weights for both branches)
+# ─────────────────────────────────────────────────────────────
+class EmbeddingNet(nn.Module):
+    """
+    ResNet-18 backbone with a custom projection head.
+    Outputs a 128-dim L2-normalised embedding.
+    """
+
+    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
+        super().__init__()
+
+        # Load pretrained ResNet-18
+        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
+        backbone = models.resnet18(weights=weights)
+
+        # Remove the final FC layer; keep everything up to avgpool
+        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
+
+        # Projection head: 512 → 256 → embedding_dim
+        self.projector = nn.Sequential(
+            nn.Linear(512, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.3),
+            nn.Linear(256, embedding_dim),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, 3, H, W)
+        feats = self.feature_extractor(x)       # (B, 512, 1, 1)
+        feats = feats.view(feats.size(0), -1)   # (B, 512)
+        emb = self.projector(feats)             # (B, embedding_dim)
+        return F.normalize(emb, p=2, dim=1)     # L2-normalised
+
+
+# ─────────────────────────────────────────────────────────────
+# Siamese Network
+# ─────────────────────────────────────────────────────────────
+class SiameseNet(nn.Module):
+    """
+    Takes two images, returns:
+      • similarity score ∈ [0, 1]  (1 = same, 0 = different)
+      • both L2-normalised embeddings
+    """
+
+    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
+        super().__init__()
+        self.encoder = EmbeddingNet(embedding_dim, pretrained)
+
+    def forward(self, img1: torch.Tensor, img2: torch.Tensor):
+        emb1 = self.encoder(img1)
+        emb2 = self.encoder(img2)
+        # Cosine similarity → [−1, 1], rescale to [0, 1]
+        cos_sim = F.cosine_similarity(emb1, emb2, dim=1)
+        score = (cos_sim + 1.0) / 2.0  # (B,)
+        return score, emb1, emb2
+
+    def get_embedding(self, img: torch.Tensor) -> torch.Tensor:
+        """Get embedding for a single image (useful at inference time)."""
+        return self.encoder(img)
+
+
+# ─────────────────────────────────────────────────────────────
+# Loss Functions
+# ─────────────────────────────────────────────────────────────
+class ContrastiveLoss(nn.Module):
+    """
+    Contrastive Loss (Hadsell, Chopra & LeCun, 2006).
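+
+        L = y · d² + (1 − y) · max(0, margin − d)²,  where d = ‖e₁ − e₂‖₂
+        (averaged over the batch; see forward() below)
+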
+        label = 1 → same pair (pull embeddings together)
+        label = 0 → diff pair (push embeddings apart, beyond margin)
+    """
+
+    def __init__(self, margin: float = 1.0):
+        super().__init__()
+        self.margin = margin
+
+    def forward(self, emb1, emb2, label):
+        dist = F.pairwise_distance(emb1, emb2)  # Euclidean distance
+        same_loss = label * dist.pow(2)
+        diff_loss = (1 - label) * F.relu(self.margin - dist).pow(2)
+        return (same_loss + diff_loss).mean()
+
+
+class BCECosineLoss(nn.Module):
+    """
+    Binary Cross-Entropy on the cosine-similarity score.
+    Simple and very effective for this task.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bce = nn.BCELoss()
+
+    def forward(self, score, label):
+        # Clamp away from exact 0/1 so BCELoss never sees log(0)
+        # (identical embeddings yield score == 1.0 exactly).
+        score = score.clamp(1e-6, 1.0 - 1e-6)
+        return self.bce(score, label.float())
diff --git a/p_hash.py b/p_hash.py
new file mode 100644
index 0000000..80bd463
--- /dev/null
+++ b/p_hash.py
@@ -0,0 +1,20 @@
+"""
+Perceptual-hash (pHash) pre-filter: cheap early exit before the deep model.
+"""
+
+from PIL import Image
+import imagehash
+
+
+def check_image(hash1: imagehash.ImageHash, hash2: imagehash.ImageHash) -> bool:
+    """Return True if the pair should proceed to deep-embedding comparison."""
+    distance = hash1 - hash2  # Hamming distance between the 64-bit hashes
+    print(f"pHash distance: {distance}")
+    if distance > 10:  # typical threshold ~5-12 depending on use case
+        print("→ pHash says different → Early exit")
+        return False
+    print("→ pHash passed → Proceeding to deep embeddings")
+    return True
+
+
+if __name__ == "__main__":
+    check_image(imagehash.phash(Image.open("u1.png")),
+                imagehash.phash(Image.open("u2.png")))
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..f769840
--- /dev/null
+++ b/train.py
@@ -0,0 +1,195 @@
+"""
+Training Script
+===============
+Usage:
+    python train.py                                             # synthetic demo
+    python train.py --train_dir data/train --val_dir data/val
+    python train.py --train_dir data/train --epochs 30 --batch_size 64
+"""
+
+import argparse
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.optim as optim
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from tqdm import tqdm
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+from model import SiameseNet, ContrastiveLoss, BCECosineLoss
+from dataset import build_dataloaders
+from dotenv import load_dotenv
+
+load_dotenv()  # pick up DATASET / VAL_DATASET from a local .env, if present
+
+
+# ─────────────────────────────────────────────────────────────
+# Metrics
+# ─────────────────────────────────────────────────────────────
+def compute_accuracy(scores: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
+    preds = (scores >= threshold).float()
+    correct = (preds == labels).sum().item()
+    return correct / len(labels)
+
+
+# ─────────────────────────────────────────────────────────────
+# One epoch
+# ─────────────────────────────────────────────────────────────
+def run_epoch(model, loader, criterion, optimizer, device, train: bool):
+    model.train(mode=train)  # toggles train/eval behaviour (BatchNorm, Dropout)
+    total_loss, total_acc, n = 0.0, 0.0, 0
+
+    ctx = torch.enable_grad() if train else torch.no_grad()
+    with ctx:
+        for img1, img2, labels in tqdm(loader, desc="train" if train else "val  ", leave=False):
+            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
+
+            scores, emb1, emb2 = model(img1, img2)
+
+            # BCECosineLoss by default; see the loss-swapping note below
+            loss = criterion(scores, labels)
+
+            if train:
+                optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+                optimizer.step()
+
+            acc = compute_accuracy(scores.detach().cpu(), labels.cpu())
+            bs = img1.size(0)
+            total_loss += loss.item() * bs
+            total_acc += acc * bs
+            n += bs
+
+    return total_loss / n, total_acc / n
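+
+# Swapping losses (illustrative): ContrastiveLoss operates on the raw embeddings
+# rather than the cosine score, so inside run_epoch the call would become
+#     loss = criterion(emb1, emb2, labels)   # with criterion = ContrastiveLoss(margin=1.0)
+# while BCECosineLoss keeps loss = criterion(scores, labels).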
+
+
+# ─────────────────────────────────────────────────────────────
+# Training loop
+# ─────────────────────────────────────────────────────────────
+def train(args):
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() else
+        "mps" if torch.backends.mps.is_available() else
+        "cpu"
+    )
+    print(f"🖥️ Device: {device}")
+
+    # Env vars are fallbacks only; don't clobber explicit CLI flags.
+    args.train_dir = args.train_dir or os.getenv("DATASET")
+    args.val_dir = args.val_dir or os.getenv("VAL_DATASET")
+
+    # ── Data ────────────────────────────────────────────────
+    train_loader, val_loader = build_dataloaders(
+        train_root      = args.train_dir,
+        val_root        = args.val_dir,
+        img_size        = args.img_size,
+        batch_size      = args.batch_size,
+        num_workers     = args.num_workers,
+        pairs_per_epoch = args.pairs_per_epoch,
+    )
+    print(f"📦 Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")
+
+    # ── Model ───────────────────────────────────────────────
+    model = SiameseNet(embedding_dim=args.embedding_dim, pretrained=True).to(device)
+    criterion = BCECosineLoss()
+    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
+    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
+
+    # ── Output dir ──────────────────────────────────────────
+    out = Path(args.output_dir)
+    out.mkdir(parents=True, exist_ok=True)
+
+    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
+    best_val_loss = float("inf")
+    patience_counter = 0
+
+    print("\n🚀 Training started\n")
+    for epoch in range(1, args.epochs + 1):
+        t0 = time.time()
+
+        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, train=True)
+        va_loss, va_acc = run_epoch(model, val_loader, criterion, None, device, train=False)
+
+        scheduler.step()
+
+        history["train_loss"].append(tr_loss)
+        history["val_loss"].append(va_loss)
+        history["train_acc"].append(tr_acc)
+        history["val_acc"].append(va_acc)
+
+        elapsed = time.time() - t0
+        print(
+            f"Epoch {epoch:3d}/{args.epochs} │ "
+            f"Train loss={tr_loss:.4f} acc={tr_acc:.3f} │ "
+            f"Val loss={va_loss:.4f} acc={va_acc:.3f} │ "
+            f"LR={scheduler.get_last_lr()[0]:.2e}  [{elapsed:.1f}s]"
+        )
+
+        # ── Checkpoint ──────────────────────────────────────
+        if va_loss < best_val_loss:
+            best_val_loss = va_loss
+            patience_counter = 0
+            ckpt_path = out / "best_model.pth"
+            torch.save({
+                "epoch": epoch,
+                "model_state": model.state_dict(),
+                "optim_state": optimizer.state_dict(),
+                "val_loss": va_loss,
+                "val_acc": va_acc,
+                "args": vars(args),
+            }, ckpt_path)
+            print(f"   ✅ Best model saved → {ckpt_path}")
+        else:
+            patience_counter += 1
+            if patience_counter >= args.patience:
+                print(f"\n⏹ Early stopping after {epoch} epochs (patience={args.patience})")
+                break
+
+        # Refresh pairs each epoch (random re-pairing)
+        if hasattr(train_loader.dataset, "on_epoch_end"):
+            train_loader.dataset.on_epoch_end()
+
+    # ── Plot training curves ────────────────────────────────
+    _plot_history(history, out / "training_curves.png")
+    print(f"\n📊 Training curves saved → {out / 'training_curves.png'}")
+    print(f"🏆 Best val loss: {best_val_loss:.4f}")
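+
+
+# Example .env consumed by load_dotenv() above (illustrative values):
+#     DATASET=data/train
+#     VAL_DATASET=data/val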
+
+
+def _plot_history(history, save_path):
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
+    epochs = range(1, len(history["train_loss"]) + 1)
+
+    ax1.plot(epochs, history["train_loss"], label="Train")
+    ax1.plot(epochs, history["val_loss"], label="Val")
+    ax1.set_title("Loss"); ax1.set_xlabel("Epoch"); ax1.legend(); ax1.grid(True)
+
+    ax2.plot(epochs, history["train_acc"], label="Train")
+    ax2.plot(epochs, history["val_acc"], label="Val")
+    ax2.set_title("Accuracy"); ax2.set_xlabel("Epoch"); ax2.set_ylim(0, 1)
+    ax2.legend(); ax2.grid(True)
+
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=150)
+    plt.close()
+
+
+# ─────────────────────────────────────────────────────────────
+# CLI
+# ─────────────────────────────────────────────────────────────
+if __name__ == "__main__":
+    p = argparse.ArgumentParser(description="Train Siamese Neural Network")
+    p.add_argument("--train_dir", default=None, help="Path to training data root (env fallback: DATASET)")
+    p.add_argument("--val_dir", default=None, help="Path to validation data root (env fallback: VAL_DATASET)")
+    p.add_argument("--output_dir", default="checkpoints")
+    p.add_argument("--img_size", type=int, default=224)
+    p.add_argument("--batch_size", type=int, default=32)
+    p.add_argument("--epochs", type=int, default=20)
+    p.add_argument("--lr", type=float, default=1e-4)
+    p.add_argument("--embedding_dim", type=int, default=128)
+    p.add_argument("--pairs_per_epoch", type=int, default=5000)
+    p.add_argument("--num_workers", type=int, default=4)
+    p.add_argument("--patience", type=int, default=5,
+                   help="Early stopping patience (epochs)")
+
+    args = p.parse_args()
+    train(args)