Add dataset, model, training, inference, and utility scripts for Siamese Network implementation

- Implemented dataset and augmentation pipeline in `dataset.py` for handling image pairs and heavy augmentations.
- Created a Siamese Network architecture in `model.py` with shared weights and contrastive loss functions.
- Developed training script in `train.py` to facilitate model training with validation and early stopping.
- Added inference capabilities in `inference.py` for comparing images and finding matches.
- Included a perceptual hashing utility in `p_hash.py` for quick image similarity checks.
AMB_DEV
bahawal.baloch 2026-05-07 17:45:33 +05:00
parent 45b2f6e98c
commit c66d3b1029
5 changed files with 805 additions and 0 deletions
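
A minimal sketch of how the pieces fit together (module, class, and
checkpoint names as used in this commit; the flow itself is illustrative):

    # 1) Train (CLI): python train.py --train_dir data/train --val_dir data/val
    # 2) Compare two images with the resulting checkpoint:
    from inference import SiamesePredictor

    predictor = SiamesePredictor(checkpoint_path="checkpoints/best_model.pth")
    result = predictor.compare("a.jpg", "b.jpg")
    print(result["similarity_score"], result["match"])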

dataset.py (new file, 275 lines)

"""
Dataset & Augmentation Pipeline
================================
Handles:
    - Random pairing (same-class → label=1, different-class → label=0)
    - Heavy augmentation for noise robustness:
        - Random crops + resize → simulates cropping
        - Random resize → simulates resolution differences
        - Gaussian blur + noise → simulates sensor noise
        - Color jitter → simulates lighting variation
        - Horizontal flip, rotation
"""
import os
import random
from pathlib import Path
from typing import Optional, Tuple, Callable
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFilter, ImageEnhance
import numpy as np
# ─────────────────────────────────────────────────────────────
# Custom Transform: Add Gaussian pixel noise
# ─────────────────────────────────────────────────────────────
class AddGaussianNoise:
def __init__(self, std: float = 0.05):
self.std = std
def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
noise = torch.randn_like(tensor) * self.std
return torch.clamp(tensor + noise, 0.0, 1.0)
# ─────────────────────────────────────────────────────────────
# Augmentation Pipelines
# ─────────────────────────────────────────────────────────────
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_train_transform(img_size: int = 224) -> transforms.Compose:
"""Heavy augmentation for training — robust to crops, resolution, noise."""
return transforms.Compose([
        # Resolution differences are covered by RandomResizedCrop below,
        # which re-samples its scale on every call. (Building a Resize
        # with random.uniform here would freeze one random size at
        # pipeline construction; see the RandomResize sketch further down.)
transforms.RandomResizedCrop(
img_size,
            scale=(0.5, 1.0),   # aggressive cropping (50–100 % of image)
ratio=(0.75, 1.33),
),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.1,
),
transforms.RandomGrayscale(p=0.1),
transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3),
transforms.ToTensor(),
AddGaussianNoise(std=0.03),
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
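# ─────────────────────────────────────────────────────────────
# Optional: per-call resolution jitter (sketch)
# ─────────────────────────────────────────────────────────────
# Implements the random-resize idea referenced above. Unlike a Resize
# built with random.uniform (sampled once at construction), this draws
# a fresh scale on every call. Name and defaults are illustrative.
class RandomResize:
    def __init__(self, lo: float = 0.8, hi: float = 1.4):
        self.lo, self.hi = lo, hi

    def __call__(self, img: Image.Image) -> Image.Image:
        w, h = img.size
        s = random.uniform(self.lo, self.hi)  # new scale on every call
        return img.resize((max(1, round(w * s)), max(1, round(h * s))),
                          Image.BILINEAR)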
def get_val_transform(img_size: int = 224) -> transforms.Compose:
"""Deterministic transform for validation/inference."""
return transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# ─────────────────────────────────────────────────────────────
# Siamese Pair Dataset
# ─────────────────────────────────────────────────────────────
class SiamesePairDataset(Dataset):
"""
Folder layout expected:
root/
class_A/ img1.jpg img2.jpg ...
class_B/ img1.jpg ...
    Each __getitem__ returns (img1_tensor, img2_tensor, label):
        label = 1 → same class
        label = 0 → different class
"""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pairs_per_epoch: int = 5000,
same_ratio: float = 0.5,
):
self.root = Path(root)
self.transform = transform
self.pairs_per_epoch = pairs_per_epoch
self.same_ratio = same_ratio
# Build class → [image paths] mapping
self.class_to_imgs: dict[str, list[Path]] = {}
for cls_dir in sorted(self.root.iterdir()):
if cls_dir.is_dir():
imgs = [
p for p in cls_dir.iterdir()
if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
]
if imgs:
self.class_to_imgs[cls_dir.name] = imgs
self.classes = list(self.class_to_imgs.keys())
assert len(self.classes) >= 2, "Need at least 2 classes."
# Pre-generate pairs for this epoch
self._pairs = self._generate_pairs()
def _generate_pairs(self):
pairs = []
n_same = int(self.pairs_per_epoch * self.same_ratio)
n_diff = self.pairs_per_epoch - n_same
# Same-class pairs (label = 1)
for _ in range(n_same):
cls = random.choice(self.classes)
imgs = self.class_to_imgs[cls]
if len(imgs) < 2:
a = b = imgs[0]
else:
a, b = random.sample(imgs, 2)
pairs.append((a, b, 1))
# Different-class pairs (label = 0)
for _ in range(n_diff):
c1, c2 = random.sample(self.classes, 2)
a = random.choice(self.class_to_imgs[c1])
b = random.choice(self.class_to_imgs[c2])
pairs.append((a, b, 0))
random.shuffle(pairs)
return pairs
def on_epoch_end(self):
"""Call between epochs to refresh random pairs."""
self._pairs = self._generate_pairs()
def __len__(self):
return len(self._pairs)
def __getitem__(self, idx):
path1, path2, label = self._pairs[idx]
img1 = Image.open(path1).convert("RGB")
img2 = Image.open(path2).convert("RGB")
if self.transform:
img1 = self.transform(img1)
img2 = self.transform(img2)
return img1, img2, torch.tensor(label, dtype=torch.float32)
# ─────────────────────────────────────────────────────────────
# Demo: Synthetic dataset (MNIST-style circles/squares)
# ─────────────────────────────────────────────────────────────
class SyntheticShapeDataset(Dataset):
"""
Generates synthetic image pairs on-the-fly.
Classes: circle, square, triangle (3 classes)
Useful for quick testing without real data.
"""
SHAPES = ["circle", "square", "triangle"]
def __init__(
self,
n_pairs: int = 2000,
img_size: int = 224,
same_ratio: float = 0.5,
augment: bool = True,
):
self.n_pairs = n_pairs
self.img_size = img_size
self.same_ratio = same_ratio
self.augment = augment
self.transform = get_train_transform(img_size) if augment else get_val_transform(img_size)
self._pairs = self._generate_pairs()
def _generate_pairs(self):
pairs = []
for _ in range(self.n_pairs):
if random.random() < self.same_ratio:
s = random.choice(self.SHAPES)
pairs.append((s, s, 1))
else:
s1, s2 = random.sample(self.SHAPES, 2)
pairs.append((s1, s2, 0))
return pairs
def _draw_shape(self, shape: str) -> Image.Image:
from PIL import ImageDraw
sz = self.img_size
img = Image.new("RGB", (sz, sz), color=(
random.randint(200, 255),
random.randint(200, 255),
random.randint(200, 255),
))
draw = ImageDraw.Draw(img)
margin = sz // 6
x0 = random.randint(margin, sz // 3)
y0 = random.randint(margin, sz // 3)
x1 = random.randint(sz * 2 // 3, sz - margin)
y1 = random.randint(sz * 2 // 3, sz - margin)
color = (random.randint(0, 100), random.randint(0, 100), random.randint(50, 150))
if shape == "circle":
draw.ellipse([x0, y0, x1, y1], fill=color)
elif shape == "square":
draw.rectangle([x0, y0, x1, y1], fill=color)
else: # triangle
draw.polygon([(sz // 2, y0), (x0, y1), (x1, y1)], fill=color)
return img
def __len__(self):
return self.n_pairs
def __getitem__(self, idx):
s1, s2, label = self._pairs[idx]
img1 = self.transform(self._draw_shape(s1))
img2 = self.transform(self._draw_shape(s2))
return img1, img2, torch.tensor(label, dtype=torch.float32)
# ─────────────────────────────────────────────────────────────
# DataLoader factory
# ─────────────────────────────────────────────────────────────
def build_dataloaders(
train_root: Optional[str] = None,
val_root: Optional[str] = None,
img_size: int = 224,
batch_size: int = 32,
num_workers: int = 4,
pairs_per_epoch: int = 5000,
) -> Tuple[DataLoader, DataLoader]:
if train_root and Path(train_root).exists():
train_ds = SiamesePairDataset(
train_root,
transform=get_train_transform(img_size),
pairs_per_epoch=pairs_per_epoch,
)
val_ds = SiamesePairDataset(
val_root or train_root,
transform=get_val_transform(img_size),
pairs_per_epoch=pairs_per_epoch // 5,
)
else:
print("⚠️ No dataset root provided — using synthetic shapes for demo.")
train_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch, img_size=img_size, augment=True)
val_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch // 5, img_size=img_size, augment=False)
train_loader = DataLoader(
train_ds, batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=True,
)
val_loader = DataLoader(
val_ds, batch_size=batch_size, shuffle=False,
num_workers=num_workers, pin_memory=True,
)
return train_loader, val_loader
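
# ─────────────────────────────────────────────────────────────
# Smoke test (illustrative)
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # No dataset root → build_dataloaders falls back to synthetic shapes.
    train_loader, val_loader = build_dataloaders(batch_size=8, num_workers=0,
                                                 pairs_per_epoch=32)
    img1, img2, labels = next(iter(train_loader))
    print(img1.shape, img2.shape, labels.shape)  # (8, 3, 224, 224) ×2, (8,)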

inference.py (new file, 202 lines)

"""
Inference Module
================
Load a trained Siamese Network and compare images.
Usage:
python inference.py --img1 a.jpg --img2 b.jpg --checkpoint checkpoints/best_model.pth
python inference.py --query logo.png --gallery_dir /images/ # find best match
"""
import argparse
from pathlib import Path
from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from model import SiameseNet
from dataset import get_val_transform, IMAGENET_MEAN, IMAGENET_STD
# ─────────────────────────────────────────────────────────────
# Predictor class
# ─────────────────────────────────────────────────────────────
class SiamesePredictor:
"""
High-level inference wrapper.
"""
def __init__(
self,
checkpoint_path: Optional[str] = None,
embedding_dim: int = 128,
device: Optional[str] = None,
threshold: float = 0.6,
):
self.device = torch.device(device or (
"cuda" if torch.cuda.is_available() else
"mps" if torch.backends.mps.is_available() else
"cpu"
))
self.threshold = threshold
self.transform = get_val_transform(224)
# Build model
self.model = SiameseNet(embedding_dim=embedding_dim, pretrained=False)
self.model.to(self.device)
if checkpoint_path and Path(checkpoint_path).exists():
ckpt = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(ckpt["model_state"])
print(f"✅ Loaded checkpoint from {checkpoint_path}")
print(f" Trained for {ckpt.get('epoch', '?')} epochs | "
f"val_acc = {ckpt.get('val_acc', 0):.3f}")
else:
print("⚠️ No checkpoint — running with random weights (demo only).")
self.model.eval()
def _load(self, path_or_pil) -> torch.Tensor:
if isinstance(path_or_pil, (str, Path)):
img = Image.open(path_or_pil).convert("RGB")
else:
img = path_or_pil.convert("RGB")
return self.transform(img).unsqueeze(0).to(self.device)
@torch.no_grad()
def compare(self, img1, img2) -> dict:
"""
Compare two images.
Returns dict with score, match, and euclidean distance.
"""
t1 = self._load(img1)
t2 = self._load(img2)
score, emb1, emb2 = self.model(t1, t2)
s = score.item()
dist = F.pairwise_distance(emb1, emb2).item()
return {
"similarity_score": round(s, 4),
"euclidean_dist": round(dist, 4),
"match": s >= self.threshold,
"confidence": "HIGH" if abs(s - 0.5) > 0.3 else
"MED" if abs(s - 0.5) > 0.15 else "LOW",
}
@torch.no_grad()
def embed(self, img) -> np.ndarray:
"""Get embedding vector for an image."""
t = self._load(img)
return self.model.get_embedding(t).cpu().numpy()[0]
@torch.no_grad()
def find_best_match(
self, query_img, gallery: List, top_k: int = 5
) -> List[Tuple[float, int]]:
"""
Find top-k most similar images from a gallery list.
gallery: list of image paths or PIL images.
Returns: list of (score, index) sorted descending.
"""
q_emb = torch.from_numpy(self.embed(query_img)).to(self.device)
scores = []
for idx, img in enumerate(gallery):
g_emb = torch.from_numpy(self.embed(img)).to(self.device)
cos = F.cosine_similarity(q_emb.unsqueeze(0), g_emb.unsqueeze(0)).item()
s = (cos + 1.0) / 2.0
scores.append((s, idx))
scores.sort(reverse=True)
return scores[:top_k]
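
# ─────────────────────────────────────────────────────────────
# Optional: batched gallery embedding (sketch)
# ─────────────────────────────────────────────────────────────
# find_best_match embeds one image per forward pass, which is simple but
# slow for large galleries. An illustrative batched variant; the helper
# name and batch size are assumptions, not part of the original API.
@torch.no_grad()
def embed_gallery(predictor: SiamesePredictor, images: List,
                  batch_size: int = 32) -> torch.Tensor:
    embs = []
    for i in range(0, len(images), batch_size):
        batch = torch.cat([predictor._load(im) for im in images[i:i + batch_size]], dim=0)
        embs.append(predictor.model.get_embedding(batch))  # (B, D), L2-normalised
    return torch.cat(embs, dim=0)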
# ─────────────────────────────────────────────────────────────
# Visualisation helper
# ─────────────────────────────────────────────────────────────
def visualise_comparison(img1_path, img2_path, result: dict, save_path="comparison.png"):
img1 = Image.open(img1_path).convert("RGB")
img2 = Image.open(img2_path).convert("RGB")
fig, axes = plt.subplots(1, 3, figsize=(14, 5))
fig.patch.set_facecolor("#0f0f1a")
for ax in axes:
ax.set_facecolor("#0f0f1a")
ax.axis("off")
axes[0].imshow(img1)
axes[0].set_title("Image 1", color="white", fontsize=13, pad=10)
axes[2].imshow(img2)
axes[2].set_title("Image 2", color="white", fontsize=13, pad=10)
# Middle panel: result
color = "#00e676" if result["match"] else "#ff1744"
symbol = "✓ MATCH" if result["match"] else "✗ NO MATCH"
axes[1].set_xlim(0, 1); axes[1].set_ylim(0, 1)
axes[1].add_patch(mpatches.FancyBboxPatch(
(0.1, 0.2), 0.8, 0.6, boxstyle="round,pad=0.05",
linewidth=3, edgecolor=color, facecolor="#1a1a2e"
))
axes[1].text(0.5, 0.75, symbol, ha="center", va="center", fontsize=20,
fontweight="bold", color=color)
axes[1].text(0.5, 0.55, f"Score: {result['similarity_score']:.3f}",
ha="center", va="center", fontsize=14, color="white")
axes[1].text(0.5, 0.40, f"Dist: {result['euclidean_dist']:.3f}",
ha="center", va="center", fontsize=11, color="#aaaaaa")
axes[1].text(0.5, 0.28, f"Conf: {result['confidence']}",
ha="center", va="center", fontsize=11, color="#aaaaaa")
plt.tight_layout(pad=2)
plt.savefig(save_path, dpi=150, bbox_inches="tight")
plt.close()
print(f"📸 Visualisation saved → {save_path}")
# ─────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
p = argparse.ArgumentParser(description="Siamese Network Inference")
p.add_argument("--img1", required=True, help="Path to image 1")
p.add_argument("--img2", default=None, help="Path to image 2 (comparison mode)")
p.add_argument("--gallery_dir", default=None, help="Directory to search (gallery mode)")
p.add_argument("--checkpoint", default="checkpoints/best_model.pth")
p.add_argument("--threshold", type=float, default=0.6, help="Match threshold")
p.add_argument("--top_k", type=int, default=5)
p.add_argument("--output", default="comparison.png")
args = p.parse_args()
predictor = SiamesePredictor(
checkpoint_path=args.checkpoint,
threshold=args.threshold,
)
if args.img2:
# ── Compare two images ────────────────────────────────
result = predictor.compare(args.img1, args.img2)
print("\n" + "" * 45)
print(f" Similarity score : {result['similarity_score']}")
print(f" Euclidean dist : {result['euclidean_dist']}")
print(f" Match : {'YES ✓' if result['match'] else 'NO ✗'}")
print(f" Confidence : {result['confidence']}")
print("" * 45)
visualise_comparison(args.img1, args.img2, result, args.output)
elif args.gallery_dir:
# ── Gallery search ────────────────────────────────────
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
gallery = [p for p in Path(args.gallery_dir).iterdir() if p.suffix.lower() in exts]
matches = predictor.find_best_match(args.img1, gallery, top_k=args.top_k)
print(f"\nTop-{args.top_k} matches for {args.img1}:")
for rank, (score, idx) in enumerate(matches, 1):
verdict = "✓ MATCH" if score >= args.threshold else " diff "
print(f" {rank}. {verdict} score={score:.4f} {gallery[idx]}")
else:
print("Provide --img2 or --gallery_dir")

model.py (new file, 113 lines)

"""
Siamese Neural Network for Image Similarity Matching
======================================================
Robust to: cropping, resolution differences, noise, rotation, color jitter.
Architecture:
- Shared CNN encoder (ResNet-18 backbone, pretrained)
- L2-normalized embedding head
- Contrastive Loss / Binary Cross-Entropy with cosine similarity
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# ─────────────────────────────────────────────────────────────
# Embedding Network (shared weights for both branches)
# ─────────────────────────────────────────────────────────────
class EmbeddingNet(nn.Module):
"""
ResNet-18 backbone with a custom projection head.
Outputs a 128-dim L2-normalised embedding.
"""
def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
super().__init__()
# Load pretrained ResNet-18
weights = models.ResNet18_Weights.DEFAULT if pretrained else None
backbone = models.resnet18(weights=weights)
# Remove the final FC layer; keep everything up to avgpool
self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
# Projection head: 512 → 256 → embedding_dim
self.projector = nn.Sequential(
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
nn.Linear(256, embedding_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, 3, H, W)
feats = self.feature_extractor(x) # (B, 512, 1, 1)
feats = feats.view(feats.size(0), -1) # (B, 512)
emb = self.projector(feats) # (B, embedding_dim)
return F.normalize(emb, p=2, dim=1) # L2-normalised
# ─────────────────────────────────────────────────────────────
# Siamese Network
# ─────────────────────────────────────────────────────────────
class SiameseNet(nn.Module):
"""
Takes two images, returns:
        similarity score ∈ [0, 1] (1 = same, 0 = different)
        both L2-normalised embeddings
"""
def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
super().__init__()
self.encoder = EmbeddingNet(embedding_dim, pretrained)
def forward(self, img1: torch.Tensor, img2: torch.Tensor):
emb1 = self.encoder(img1)
emb2 = self.encoder(img2)
        # Cosine similarity ∈ [-1, 1], rescale to [0, 1]
cos_sim = F.cosine_similarity(emb1, emb2, dim=1)
score = (cos_sim + 1.0) / 2.0 # (B,)
return score, emb1, emb2
def get_embedding(self, img: torch.Tensor) -> torch.Tensor:
"""Get embedding for a single image (useful at inference time)."""
return self.encoder(img)
# ─────────────────────────────────────────────────────────────
# Loss Functions
# ─────────────────────────────────────────────────────────────
class ContrastiveLoss(nn.Module):
"""
    Contrastive loss (Hadsell, Chopra & LeCun, 2006):
        L = y · d² + (1 − y) · max(0, margin − d)²
    label = 1 → same pair (pull embeddings together)
    label = 0 → different pair (push embeddings apart, beyond margin)
"""
def __init__(self, margin: float = 1.0):
super().__init__()
self.margin = margin
def forward(self, emb1, emb2, label):
dist = F.pairwise_distance(emb1, emb2) # Euclidean distance
same_loss = label * dist.pow(2)
diff_loss = (1 - label) * F.relu(self.margin - dist).pow(2)
return (same_loss + diff_loss).mean()
class BCECosineLoss(nn.Module):
"""
Binary Cross-Entropy on cosine-similarity score.
Simple and very effective for this task.
"""
def __init__(self):
super().__init__()
self.bce = nn.BCELoss()
def forward(self, score, label):
return self.bce(score, label.float())
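
# ─────────────────────────────────────────────────────────────
# Smoke test (illustrative)
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Shape check with random inputs; pretrained=False avoids a weight download.
    net = SiameseNet(embedding_dim=128, pretrained=False).eval()
    a, b = torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        score, e1, e2 = net(a, b)
    print(score.shape, e1.shape)  # torch.Size([2]) torch.Size([2, 128])
    print(e1.norm(dim=1))         # ≈ 1.0 each (L2-normalised)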

p_hash.py (new file, 20 lines)

from PIL import Image
import imagehash

def check_image(hash1: imagehash.ImageHash, hash2: imagehash.ImageHash) -> bool:
    """pHash pre-filter: compare perceptual hashes by Hamming distance."""
    distance = hash1 - hash2  # imagehash overloads '-' as Hamming distance
    print(hash1, hash2, distance)
    if distance > 10:  # typical threshold ~5-12 depending on use case
        print("→ pHash says different → Early exit")
        return False
    print("→ pHash passed → Proceeding to deep embeddings")
    return True

hash1 = imagehash.phash(Image.open("u1.png"))
hash2 = imagehash.phash(Image.open("u2.png"))
check_image(hash1, hash2)
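
# ─────────────────────────────────────────────────────────────
# Sketch: pHash early-exit in front of the Siamese model
# ─────────────────────────────────────────────────────────────
# Illustrative wiring to inference.SiamesePredictor; the default
# checkpoint path is an assumption.
def similar(path_a: str, path_b: str,
            checkpoint: str = "checkpoints/best_model.pth") -> bool:
    h1 = imagehash.phash(Image.open(path_a))
    h2 = imagehash.phash(Image.open(path_b))
    if h1 - h2 > 10:  # cheap reject, no deep model needed
        return False
    from inference import SiamesePredictor  # deferred: torch loads only past the pre-filter
    predictor = SiamesePredictor(checkpoint_path=checkpoint)
    return predictor.compare(path_a, path_b)["match"]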

train.py (new file, 195 lines)

"""
Training Script
===============
Usage:
python train.py # synthetic demo
python train.py --train_dir data/train --val_dir data/val
python train.py --train_dir data/train --epochs 30 --batch_size 64
"""
import argparse
import os
import time
from pathlib import Path
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from model import SiameseNet, ContrastiveLoss, BCECosineLoss
from dataset import build_dataloaders
from dotenv import load_dotenv
load_dotenv()
# ─────────────────────────────────────────────────────────────
# Metrics
# ─────────────────────────────────────────────────────────────
def compute_accuracy(scores: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
preds = (scores >= threshold).float()
correct = (preds == labels).sum().item()
return correct / len(labels)
# ─────────────────────────────────────────────────────────────
# One epoch
# ─────────────────────────────────────────────────────────────
def run_epoch(model, loader, criterion, optimizer, device, train: bool):
model.train() if train else model.eval()
total_loss, total_acc, n = 0.0, 0.0, 0
ctx = torch.enable_grad() if train else torch.no_grad()
with ctx:
for img1, img2, labels in tqdm(loader, desc="train" if train else "val ", leave=False):
img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
scores, emb1, emb2 = model(img1, img2)
            # BCECosineLoss by default. Note ContrastiveLoss has a different
            # signature: call it as criterion(emb1, emb2, labels) instead.
            loss = criterion(scores, labels)
if train:
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
acc = compute_accuracy(scores.detach().cpu(), labels.cpu())
bs = img1.size(0)
total_loss += loss.item() * bs
total_acc += acc * bs
n += bs
return total_loss / n, total_acc / n
# ─────────────────────────────────────────────────────────────
# Training loop
# ─────────────────────────────────────────────────────────────
def train(args):
device = torch.device(
"cuda" if torch.cuda.is_available() else
"mps" if torch.backends.mps.is_available() else
"cpu"
)
print(f"🖥️ Device: {device}")
    # Let .env provide dataset paths, but don't clobber explicit CLI args
    args.train_dir = args.train_dir or os.getenv("DATASET")
    args.val_dir = args.val_dir or os.getenv("VAL_DATASET")
# ── Data ──────────────────────────────────────────────────
train_loader, val_loader = build_dataloaders(
train_root = args.train_dir,
val_root = args.val_dir,
img_size = args.img_size,
batch_size = args.batch_size,
num_workers = args.num_workers,
pairs_per_epoch = args.pairs_per_epoch,
)
print(f"📦 Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")
# ── Model ─────────────────────────────────────────────────
model = SiameseNet(embedding_dim=args.embedding_dim, pretrained=True).to(device)
criterion = BCECosineLoss()
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
# ── Output dir ────────────────────────────────────────────
out = Path(args.output_dir)
out.mkdir(parents=True, exist_ok=True)
history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
best_val_loss = float("inf")
patience_counter = 0
print("\n🚀 Training started\n")
for epoch in range(1, args.epochs + 1):
t0 = time.time()
tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, train=True)
va_loss, va_acc = run_epoch(model, val_loader, criterion, None, device, train=False)
scheduler.step()
history["train_loss"].append(tr_loss)
history["val_loss"].append(va_loss)
history["train_acc"].append(tr_acc)
history["val_acc"].append(va_acc)
elapsed = time.time() - t0
print(
f"Epoch {epoch:3d}/{args.epochs}"
f"Train loss={tr_loss:.4f} acc={tr_acc:.3f}"
f"Val loss={va_loss:.4f} acc={va_acc:.3f}"
f"LR={scheduler.get_last_lr()[0]:.2e} [{elapsed:.1f}s]"
)
# ── Checkpoint ────────────────────────────────────────
if va_loss < best_val_loss:
best_val_loss = va_loss
patience_counter = 0
ckpt_path = out / "best_model.pth"
torch.save({
"epoch": epoch,
"model_state": model.state_dict(),
"optim_state": optimizer.state_dict(),
"val_loss": va_loss,
"val_acc": va_acc,
"args": vars(args),
}, ckpt_path)
print(f" ✅ Best model saved → {ckpt_path}")
else:
patience_counter += 1
if patience_counter >= args.patience:
print(f"\n⏹ Early stopping after {epoch} epochs (patience={args.patience})")
break
# Refresh pairs each epoch (random re-pairing)
if hasattr(train_loader.dataset, "on_epoch_end"):
train_loader.dataset.on_epoch_end()
# ── Plot training curves ───────────────────────────────────
_plot_history(history, out / "training_curves.png")
print(f"\n📊 Training curves saved → {out / 'training_curves.png'}")
print(f"🏆 Best val loss: {best_val_loss:.4f}")
def _plot_history(history, save_path):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
epochs = range(1, len(history["train_loss"]) + 1)
ax1.plot(epochs, history["train_loss"], label="Train")
ax1.plot(epochs, history["val_loss"], label="Val")
ax1.set_title("Loss"); ax1.set_xlabel("Epoch"); ax1.legend(); ax1.grid(True)
ax2.plot(epochs, history["train_acc"], label="Train")
ax2.plot(epochs, history["val_acc"], label="Val")
ax2.set_title("Accuracy"); ax2.set_xlabel("Epoch"); ax2.set_ylim(0, 1)
ax2.legend(); ax2.grid(True)
plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.close()
# ─────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
p = argparse.ArgumentParser(description="Train Siamese Neural Network")
p.add_argument("--train_dir", default=None, help="Path to training data root")
p.add_argument("--val_dir", default=None, help="Path to validation data root")
p.add_argument("--output_dir", default="checkpoints")
p.add_argument("--img_size", type=int, default=224)
p.add_argument("--batch_size", type=int, default=32)
p.add_argument("--epochs", type=int, default=20)
p.add_argument("--lr", type=float, default=1e-4)
p.add_argument("--embedding_dim", type=int, default=128)
p.add_argument("--pairs_per_epoch", type=int, default=5000)
p.add_argument("--num_workers", type=int, default=4)
p.add_argument("--patience", type=int, default=5,
help="Early stopping patience (epochs)")
args = p.parse_args()
train(args)