Add dataset, model, training, inference, and utility scripts for Siamese Network implementation
- Implemented dataset and augmentation pipeline in `dataset.py` for handling image pairs and heavy augmentations. - Created a Siamese Network architecture in `model.py` with shared weights and contrastive loss functions. - Developed training script in `train.py` to facilitate model training with validation and early stopping. - Added inference capabilities in `inference.py` for comparing images and finding matches. - Included a perceptual hashing utility in `p_hash.py` for quick image similarity checks.
parent
45b2f6e98c
commit
c66d3b1029
|
|
@ -0,0 +1,275 @@
|
|||
"""
|
||||
Dataset & Augmentation Pipeline
|
||||
================================
|
||||
Handles:
|
||||
• Random pairing (same-class → label=1, diff-class → label=0)
|
||||
• Heavy augmentation for noise robustness:
|
||||
- Random crops + resize → simulates cropping
|
||||
- Random resize → simulates resolution differences
|
||||
- Gaussian blur + noise → simulates sensor noise
|
||||
- Color jitter → lighting variation
|
||||
- Horizontal flip, rotation
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Callable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
from PIL import Image, ImageFilter, ImageEnhance
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Custom Transform: Add Gaussian pixel noise
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
class AddGaussianNoise:
    """Tensor transform that perturbs every element with zero-mean Gaussian noise.

    Meant to run after ``ToTensor`` (values in [0, 1]); the result is clamped
    back into that range so the subsequent ``Normalize`` stays well-behaved.
    """

    def __init__(self, std: float = 0.05):
        # Standard deviation of the additive noise; 0 makes this a no-op.
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        noisy = tensor + self.std * torch.randn_like(tensor)
        return noisy.clamp(0.0, 1.0)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Augmentation Pipelines
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
||||
IMAGENET_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
def get_train_transform(img_size: int = 224) -> transforms.Compose:
    """Heavy augmentation for training — robust to crops, resolution, noise.

    Args:
        img_size: side length of the square tensor produced by the pipeline.

    Returns:
        A torchvision ``Compose`` ending in ``ToTensor`` + ImageNet
        normalisation, suitable for feeding the ResNet-based encoder.
    """
    # NOTE(review): the previous version wrapped this Resize in a
    # `<random tuple> if False else img_size` placeholder. The random branch
    # could never execute, and even if enabled it would have sampled its
    # randomness once at pipeline-construction time rather than per image.
    # The effective behaviour was always Resize(img_size), which is what we
    # now state explicitly; per-image scale/resolution variation is already
    # provided by RandomResizedCrop below.
    return transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomResizedCrop(
            img_size,
            scale=(0.5, 1.0),       # aggressive cropping (50 – 100 % of image)
            ratio=(0.75, 1.33),
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(     # lighting variation
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1,
        ),
        transforms.RandomGrayscale(p=0.1),
        # Occasional blur simulates sensor/defocus noise.
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.03),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
|
||||
|
||||
|
||||
def get_val_transform(img_size: int = 224) -> transforms.Compose:
    """Deterministic resize → tensor → normalise pipeline for validation/inference."""
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Siamese Pair Dataset
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
class SiamesePairDataset(Dataset):
    """
    Random-pair dataset over a class-per-folder image tree.

    Folder layout expected:
        root/
            class_A/ img1.jpg img2.jpg ...
            class_B/ img1.jpg ...

    Each __getitem__ returns (img1_tensor, img2_tensor, label)
        label = 1 → same class
        label = 0 → different class
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pairs_per_epoch: int = 5000,
        same_ratio: float = 0.5,
    ):
        self.root = Path(root)
        self.transform = transform
        self.pairs_per_epoch = pairs_per_epoch
        self.same_ratio = same_ratio

        # Map class name → list of image paths, skipping files with
        # unsupported extensions and folders that end up empty.
        self.class_to_imgs: dict[str, list[Path]] = {}
        for class_dir in sorted(self.root.iterdir()):
            if not class_dir.is_dir():
                continue
            images = [
                f for f in class_dir.iterdir()
                if f.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
            ]
            if images:
                self.class_to_imgs[class_dir.name] = images

        self.classes = list(self.class_to_imgs.keys())
        assert len(self.classes) >= 2, "Need at least 2 classes."

        # Pairs are materialised up front; refresh via on_epoch_end().
        self._pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Sample a fresh epoch's worth of (path, path, label) triples."""
        n_same = int(self.pairs_per_epoch * self.same_ratio)
        n_diff = self.pairs_per_epoch - n_same
        pairs = []

        # Positive pairs (label = 1): two distinct images of one class,
        # duplicating the single image for singleton classes.
        for _ in range(n_same):
            cls = random.choice(self.classes)
            candidates = self.class_to_imgs[cls]
            if len(candidates) < 2:
                first = second = candidates[0]
            else:
                first, second = random.sample(candidates, 2)
            pairs.append((first, second, 1))

        # Negative pairs (label = 0): one image from each of two classes.
        for _ in range(n_diff):
            cls_a, cls_b = random.sample(self.classes, 2)
            pairs.append((
                random.choice(self.class_to_imgs[cls_a]),
                random.choice(self.class_to_imgs[cls_b]),
                0,
            ))

        random.shuffle(pairs)
        return pairs

    def on_epoch_end(self):
        """Call between epochs to refresh random pairs."""
        self._pairs = self._generate_pairs()

    def __len__(self):
        return len(self._pairs)

    def __getitem__(self, idx):
        path_a, path_b, label = self._pairs[idx]
        first = Image.open(path_a).convert("RGB")
        second = Image.open(path_b).convert("RGB")

        if self.transform:
            first = self.transform(first)
            second = self.transform(second)

        return first, second, torch.tensor(label, dtype=torch.float32)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Demo: Synthetic dataset (MNIST-style circles/squares)
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
class SyntheticShapeDataset(Dataset):
    """
    Generates synthetic image pairs on-the-fly.
    Classes: circle, square, triangle (3 classes)

    Each item is a (tensor, tensor, label) triple — the same contract as
    SiamesePairDataset — but images are drawn procedurally, so the class is
    handy for smoke-testing the training loop without any real data.
    """

    SHAPES = ["circle", "square", "triangle"]

    def __init__(
        self,
        n_pairs: int = 2000,
        img_size: int = 224,
        same_ratio: float = 0.5,
        augment: bool = True,
    ):
        self.n_pairs = n_pairs
        self.img_size = img_size
        self.same_ratio = same_ratio
        self.augment = augment
        # Augmented pipeline when training, deterministic one otherwise.
        if augment:
            self.transform = get_train_transform(img_size)
        else:
            self.transform = get_val_transform(img_size)
        self._pairs = self._generate_pairs()

    def _generate_pairs(self):
        """Sample (shape, shape, label) triples for one epoch."""
        pairs = []
        for _ in range(self.n_pairs):
            if random.random() < self.same_ratio:
                shape = random.choice(self.SHAPES)
                pairs.append((shape, shape, 1))
            else:
                first, second = random.sample(self.SHAPES, 2)
                pairs.append((first, second, 0))
        return pairs

    def _draw_shape(self, shape: str) -> Image.Image:
        """Render one randomly-placed, randomly-coloured shape on a light canvas."""
        from PIL import ImageDraw

        side = self.img_size
        # Light, slightly randomised background colour.
        background = (
            random.randint(200, 255),
            random.randint(200, 255),
            random.randint(200, 255),
        )
        canvas = Image.new("RGB", (side, side), color=background)
        draw = ImageDraw.Draw(canvas)

        # Random bounding box kept well inside the canvas.
        margin = side // 6
        left = random.randint(margin, side // 3)
        top = random.randint(margin, side // 3)
        right = random.randint(side * 2 // 3, side - margin)
        bottom = random.randint(side * 2 // 3, side - margin)
        fill = (random.randint(0, 100), random.randint(0, 100), random.randint(50, 150))

        if shape == "circle":
            draw.ellipse([left, top, right, bottom], fill=fill)
        elif shape == "square":
            draw.rectangle([left, top, right, bottom], fill=fill)
        else:  # triangle
            draw.polygon([(side // 2, top), (left, bottom), (right, bottom)], fill=fill)
        return canvas

    def __len__(self):
        return self.n_pairs

    def __getitem__(self, idx):
        shape_a, shape_b, label = self._pairs[idx]
        first = self.transform(self._draw_shape(shape_a))
        second = self.transform(self._draw_shape(shape_b))
        return first, second, torch.tensor(label, dtype=torch.float32)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# DataLoader factory
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
def build_dataloaders(
    train_root: Optional[str] = None,
    val_root: Optional[str] = None,
    img_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 4,
    pairs_per_epoch: int = 5000,
) -> Tuple[DataLoader, DataLoader]:
    """Create train/val DataLoaders, falling back to synthetic shapes.

    When ``train_root`` names an existing directory, real image folders are
    used (validation reuses ``train_root`` if ``val_root`` is not given);
    otherwise both loaders serve SyntheticShapeDataset for demo purposes.
    """
    use_real_data = bool(train_root) and Path(train_root).exists()

    if use_real_data:
        train_ds = SiamesePairDataset(
            train_root,
            transform=get_train_transform(img_size),
            pairs_per_epoch=pairs_per_epoch,
        )
        # Validation keeps 1/5 of the training pair budget.
        val_ds = SiamesePairDataset(
            val_root if val_root else train_root,
            transform=get_val_transform(img_size),
            pairs_per_epoch=pairs_per_epoch // 5,
        )
    else:
        print("⚠️ No dataset root provided — using synthetic shapes for demo.")
        train_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch, img_size=img_size, augment=True)
        val_ds = SyntheticShapeDataset(n_pairs=pairs_per_epoch // 5, img_size=img_size, augment=False)

    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, **shared)
    return train_loader, val_loader
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
Inference Module
|
||||
================
|
||||
Load a trained Siamese Network and compare images.
|
||||
|
||||
Usage:
|
||||
python inference.py --img1 a.jpg --img2 b.jpg --checkpoint checkpoints/best_model.pth
|
||||
python inference.py --query logo.png --gallery_dir /images/ # find best match
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
|
||||
from model import SiameseNet
|
||||
from dataset import get_val_transform, IMAGENET_MEAN, IMAGENET_STD
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Predictor class
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
class SiamesePredictor:
    """
    High-level inference wrapper.

    Owns the device, the deterministic preprocessing transform, and a
    SiameseNet whose weights are optionally restored from a checkpoint.
    """

    def __init__(
        self,
        checkpoint_path: Optional[str] = None,
        embedding_dim: int = 128,
        device: Optional[str] = None,
        threshold: float = 0.6,
    ):
        # Caller-supplied device wins; otherwise pick the best available.
        if not device:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = torch.device(device)
        self.threshold = threshold
        self.transform = get_val_transform(224)

        # Build model (random init; weights come from the checkpoint below).
        self.model = SiameseNet(embedding_dim=embedding_dim, pretrained=False)
        self.model.to(self.device)

        if checkpoint_path and Path(checkpoint_path).exists():
            ckpt = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(ckpt["model_state"])
            print(f"✅ Loaded checkpoint from {checkpoint_path}")
            print(f" Trained for {ckpt.get('epoch', '?')} epochs | "
                  f"val_acc = {ckpt.get('val_acc', 0):.3f}")
        else:
            print("⚠️ No checkpoint — running with random weights (demo only).")

        self.model.eval()

    def _load(self, path_or_pil) -> torch.Tensor:
        """Open/convert to RGB, preprocess, and add a batch dim on device."""
        if isinstance(path_or_pil, (str, Path)):
            pil = Image.open(path_or_pil)
        else:
            pil = path_or_pil
        return self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def compare(self, img1, img2) -> dict:
        """
        Compare two images.
        Returns dict with score, match, and euclidean distance.
        """
        t1, t2 = self._load(img1), self._load(img2)
        score, emb1, emb2 = self.model(t1, t2)
        similarity = score.item()
        distance = F.pairwise_distance(emb1, emb2).item()

        # Confidence is distance of the score from the 0.5 decision midpoint.
        margin = abs(similarity - 0.5)
        if margin > 0.3:
            confidence = "HIGH"
        elif margin > 0.15:
            confidence = "MED"
        else:
            confidence = "LOW"

        return {
            "similarity_score": round(similarity, 4),
            "euclidean_dist": round(distance, 4),
            "match": similarity >= self.threshold,
            "confidence": confidence,
        }

    @torch.no_grad()
    def embed(self, img) -> np.ndarray:
        """Get embedding vector for an image."""
        return self.model.get_embedding(self._load(img)).cpu().numpy()[0]

    @torch.no_grad()
    def find_best_match(
        self, query_img, gallery: List, top_k: int = 5
    ) -> List[Tuple[float, int]]:
        """
        Find top-k most similar images from a gallery list.
        gallery: list of image paths or PIL images.
        Returns: list of (score, index) sorted descending.
        """
        query = torch.from_numpy(self.embed(query_img)).to(self.device)
        ranked = []
        for position, candidate in enumerate(gallery):
            emb = torch.from_numpy(self.embed(candidate)).to(self.device)
            cosine = F.cosine_similarity(query.unsqueeze(0), emb.unsqueeze(0)).item()
            # Rescale cosine from [−1, 1] to [0, 1] to match compare().
            ranked.append(((cosine + 1.0) / 2.0, position))
        ranked.sort(reverse=True)
        return ranked[:top_k]
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Visualisation helper
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
def visualise_comparison(img1_path, img2_path, result: dict, save_path="comparison.png"):
    """Render a side-by-side comparison panel and save it to *save_path*.

    Layout: image 1 | verdict card (match / score / dist / confidence) | image 2,
    on a dark background. *result* is the dict produced by
    SiamesePredictor.compare().
    """
    img1 = Image.open(img1_path).convert("RGB")
    img2 = Image.open(img2_path).convert("RGB")

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    fig.patch.set_facecolor("#0f0f1a")  # dark theme background

    # Hide ticks/frames on all three panels.
    for ax in axes:
        ax.set_facecolor("#0f0f1a")
        ax.axis("off")

    axes[0].imshow(img1)
    axes[0].set_title("Image 1", color="white", fontsize=13, pad=10)

    axes[2].imshow(img2)
    axes[2].set_title("Image 2", color="white", fontsize=13, pad=10)

    # Middle panel: result
    # Green card for a match, red otherwise.
    color = "#00e676" if result["match"] else "#ff1744"
    symbol = "✓ MATCH" if result["match"] else "✗ NO MATCH"
    axes[1].set_xlim(0, 1); axes[1].set_ylim(0, 1)
    axes[1].add_patch(mpatches.FancyBboxPatch(
        (0.1, 0.2), 0.8, 0.6, boxstyle="round,pad=0.05",
        linewidth=3, edgecolor=color, facecolor="#1a1a2e"
    ))
    # Verdict, score, distance and confidence stacked inside the card.
    axes[1].text(0.5, 0.75, symbol, ha="center", va="center", fontsize=20,
                 fontweight="bold", color=color)
    axes[1].text(0.5, 0.55, f"Score: {result['similarity_score']:.3f}",
                 ha="center", va="center", fontsize=14, color="white")
    axes[1].text(0.5, 0.40, f"Dist: {result['euclidean_dist']:.3f}",
                 ha="center", va="center", fontsize=11, color="#aaaaaa")
    axes[1].text(0.5, 0.28, f"Conf: {result['confidence']}",
                 ha="center", va="center", fontsize=11, color="#aaaaaa")

    plt.tight_layout(pad=2)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"📸 Visualisation saved → {save_path}")
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# CLI
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
if __name__ == "__main__":
    # CLI entry point: either compare two images (--img2) or rank the
    # contents of a directory against a query image (--gallery_dir).
    p = argparse.ArgumentParser(description="Siamese Network Inference")
    p.add_argument("--img1", required=True, help="Path to image 1")
    p.add_argument("--img2", default=None, help="Path to image 2 (comparison mode)")
    p.add_argument("--gallery_dir", default=None, help="Directory to search (gallery mode)")
    p.add_argument("--checkpoint", default="checkpoints/best_model.pth")
    p.add_argument("--threshold", type=float, default=0.6, help="Match threshold")
    p.add_argument("--top_k", type=int, default=5)
    p.add_argument("--output", default="comparison.png")
    args = p.parse_args()

    # Loads checkpoint weights if present, else runs with random weights.
    predictor = SiamesePredictor(
        checkpoint_path=args.checkpoint,
        threshold=args.threshold,
    )

    if args.img2:
        # ── Compare two images ────────────────────────────────
        result = predictor.compare(args.img1, args.img2)
        print("\n" + "═" * 45)
        print(f" Similarity score : {result['similarity_score']}")
        print(f" Euclidean dist : {result['euclidean_dist']}")
        print(f" Match : {'YES ✓' if result['match'] else 'NO ✗'}")
        print(f" Confidence : {result['confidence']}")
        print("═" * 45)
        visualise_comparison(args.img1, args.img2, result, args.output)

    elif args.gallery_dir:
        # ── Gallery search ────────────────────────────────────
        exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
        gallery = [p for p in Path(args.gallery_dir).iterdir() if p.suffix.lower() in exts]
        matches = predictor.find_best_match(args.img1, gallery, top_k=args.top_k)
        print(f"\nTop-{args.top_k} matches for {args.img1}:")
        for rank, (score, idx) in enumerate(matches, 1):
            # Annotate each ranked hit with a match/diff verdict.
            verdict = "✓ MATCH" if score >= args.threshold else " diff "
            print(f" {rank}. {verdict} score={score:.4f} {gallery[idx]}")
    else:
        print("Provide --img2 or --gallery_dir")
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
"""
|
||||
Siamese Neural Network for Image Similarity Matching
|
||||
======================================================
|
||||
Robust to: cropping, resolution differences, noise, rotation, color jitter.
|
||||
|
||||
Architecture:
|
||||
- Shared CNN encoder (ResNet-18 backbone, pretrained)
|
||||
- L2-normalized embedding head
|
||||
- Contrastive Loss / Binary Cross-Entropy with cosine similarity
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Embedding Network (shared weights for both branches)
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
class EmbeddingNet(nn.Module):
    """
    ResNet-18 backbone with a custom projection head.
    Outputs a 128-dim L2-normalised embedding.
    """

    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
        super().__init__()

        # Backbone: ResNet-18, optionally initialised with ImageNet weights.
        resnet = models.resnet18(
            weights=models.ResNet18_Weights.DEFAULT if pretrained else None
        )

        # Keep every layer up to (and including) the global avgpool;
        # drop only the final classification FC layer.
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])

        # Projection head: 512 → 256 → embedding_dim
        self.projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, embedding_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images (B, 3, H, W) to unit-norm embeddings (B, D)."""
        pooled = self.feature_extractor(x)       # (B, 512, 1, 1)
        flat = pooled.view(pooled.size(0), -1)   # (B, 512)
        projected = self.projector(flat)         # (B, embedding_dim)
        return F.normalize(projected, p=2, dim=1)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Siamese Network
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
class SiameseNet(nn.Module):
    """
    Takes two images, returns:
        • similarity score ∈ [0, 1] (1 = same, 0 = different)
        • both L2-normalised embeddings
    """

    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
        super().__init__()
        # A single encoder serves both branches → true weight sharing.
        self.encoder = EmbeddingNet(embedding_dim, pretrained)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor):
        """Encode both inputs and score them by rescaled cosine similarity."""
        emb1, emb2 = self.encoder(img1), self.encoder(img2)
        # cosine ∈ [−1, 1] → affine map onto [0, 1]
        score = (F.cosine_similarity(emb1, emb2, dim=1) + 1.0) / 2.0  # (B,)
        return score, emb1, emb2

    def get_embedding(self, img: torch.Tensor) -> torch.Tensor:
        """Get embedding for a single image (useful at inference time)."""
        return self.encoder(img)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Loss Functions
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
class ContrastiveLoss(nn.Module):
    """
    Contrastive Loss (LeCun 2005).
        label = 1 → same pair (pull embeddings together)
        label = 0 → diff pair (push embeddings apart, beyond margin)
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        # Distance beyond which negative pairs contribute zero loss.
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Mean contrastive loss over the batch."""
        distance = F.pairwise_distance(emb1, emb2)
        # Positives are penalised by squared distance; negatives only
        # while they still sit inside the margin.
        pull = label * distance.pow(2)
        push = (1 - label) * F.relu(self.margin - distance).pow(2)
        return (pull + push).mean()
|
||||
|
||||
|
||||
class BCECosineLoss(nn.Module):
    """
    Binary Cross-Entropy on cosine-similarity score.
    Simple and very effective for this task.
    """

    def __init__(self):
        super().__init__()
        # Scores arrive already in [0, 1], so plain BCELoss applies directly.
        self.bce = nn.BCELoss()

    def forward(self, score, label):
        """BCE between predicted similarity scores and 0/1 pair labels."""
        targets = label.float()
        return self.bce(score, targets)
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
"""Quick perceptual-hash pre-check before running the deep Siamese model."""
from PIL import Image
import imagehash
import os  # noqa: F401  — kept so external `from p_hash import os` users don't break

# NOTE(review): the previous `OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")`
# line was dead code — nothing in this script talks to OpenAI — and reading
# a secret for no reason is misleading, so it has been removed.

# pHash of the two images under comparison. Computed at import time, so
# u1.png / u2.png must exist in the current working directory.
img1 = imagehash.phash(Image.open("u1.png"))
img2 = imagehash.phash(Image.open("u2.png"))
|
||||
def check_image(img1, img2):
|
||||
print(img1)
|
||||
print(img2)
|
||||
distance = img1 - img2
|
||||
print(distance)
|
||||
if distance > 10: # typical threshold ~5-12 depending on use case
|
||||
print("→ pHash says different → Early exit")
|
||||
else:
|
||||
print("→ pHash passed → Proceeding to deep embeddings")
|
||||
|
||||
|
||||
check_image(img1, img2)
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
"""
|
||||
Training Script
|
||||
===============
|
||||
Usage:
|
||||
python train.py # synthetic demo
|
||||
python train.py --train_dir data/train --val_dir data/val
|
||||
python train.py --train_dir data/train --epochs 30 --batch_size 64
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from tqdm import tqdm
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from model import SiameseNet, ContrastiveLoss, BCECosineLoss
|
||||
from dataset import build_dataloaders
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Metrics
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
def compute_accuracy(scores: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
    """Fraction of pairs whose thresholded score equals its 0/1 label."""
    predicted = (scores >= threshold).to(labels.dtype)
    n_correct = int((predicted == labels).sum())
    return n_correct / len(labels)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# One epoch
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
def run_epoch(model, loader, criterion, optimizer, device, train: bool):
    """Run one full pass over *loader*; returns (mean_loss, mean_accuracy).

    With train=True gradients are enabled and the optimizer steps after
    every batch (with grad-norm clipping); otherwise the model is put in
    eval mode and the pass runs under torch.no_grad().
    """
    if train:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    grad_ctx = torch.enable_grad() if train else torch.no_grad()
    with grad_ctx:
        for img1, img2, labels in tqdm(loader, desc="train" if train else "val ", leave=False):
            img1 = img1.to(device)
            img2 = img2.to(device)
            labels = labels.to(device)

            scores, emb1, emb2 = model(img1, img2)

            # Criterion is BCECosineLoss by default; ContrastiveLoss would
            # take (emb1, emb2, labels) instead.
            loss = criterion(scores, labels)

            if train:
                optimizer.zero_grad()
                loss.backward()
                # Clip to stabilise early training.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Weight batch statistics by batch size for a correct mean.
            batch = img1.size(0)
            loss_sum += loss.item() * batch
            acc_sum += compute_accuracy(scores.detach().cpu(), labels.cpu()) * batch
            seen += batch

    return loss_sum / seen, acc_sum / seen
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# Training loop
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
def train(args):
    """Full training run: data, model, optimisation, early stopping,
    best-model checkpointing and training-curve plotting.

    Args:
        args: argparse.Namespace from the CLI (see __main__ block).
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )
    print(f"🖥️ Device: {device}")

    # Environment variables act as *fallbacks* for the CLI flags.
    # (Previously they unconditionally overwrote --train_dir/--val_dir —
    # silently discarding the flags, and wiping them to None whenever the
    # variables were unset.)
    args.train_dir = args.train_dir or os.getenv("DATASET")
    args.val_dir = args.val_dir or os.getenv("VAL_DATASET")

    # ── Data ──────────────────────────────────────────────────
    train_loader, val_loader = build_dataloaders(
        train_root = args.train_dir,
        val_root = args.val_dir,
        img_size = args.img_size,
        batch_size = args.batch_size,
        num_workers = args.num_workers,
        pairs_per_epoch = args.pairs_per_epoch,
    )
    print(f"📦 Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

    # ── Model ─────────────────────────────────────────────────
    model = SiameseNet(embedding_dim=args.embedding_dim, pretrained=True).to(device)
    criterion = BCECosineLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)

    # ── Output dir ────────────────────────────────────────────
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_loss = float("inf")
    patience_counter = 0

    print("\n🚀 Training started\n")
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, train=True)
        va_loss, va_acc = run_epoch(model, val_loader, criterion, None, device, train=False)

        scheduler.step()

        history["train_loss"].append(tr_loss)
        history["val_loss"].append(va_loss)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(va_acc)

        elapsed = time.time() - t0
        print(
            f"Epoch {epoch:3d}/{args.epochs} │ "
            f"Train loss={tr_loss:.4f} acc={tr_acc:.3f} │ "
            f"Val loss={va_loss:.4f} acc={va_acc:.3f} │ "
            f"LR={scheduler.get_last_lr()[0]:.2e} [{elapsed:.1f}s]"
        )

        # ── Checkpoint: keep only the best-val-loss weights ───
        if va_loss < best_val_loss:
            best_val_loss = va_loss
            patience_counter = 0
            ckpt_path = out / "best_model.pth"
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optim_state": optimizer.state_dict(),
                "val_loss": va_loss,
                "val_acc": va_acc,
                "args": vars(args),
            }, ckpt_path)
            print(f" ✅ Best model saved → {ckpt_path}")
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"\n⏹ Early stopping after {epoch} epochs (patience={args.patience})")
                break

        # Refresh pairs each epoch (random re-pairing)
        if hasattr(train_loader.dataset, "on_epoch_end"):
            train_loader.dataset.on_epoch_end()

    # ── Plot training curves ───────────────────────────────────
    _plot_history(history, out / "training_curves.png")
    print(f"\n📊 Training curves saved → {out / 'training_curves.png'}")
    print(f"🏆 Best val loss: {best_val_loss:.4f}")
|
||||
|
||||
|
||||
def _plot_history(history, save_path):
    """Save a two-panel (loss, accuracy) training-curve figure to *save_path*."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    xs = range(1, len(history["train_loss"]) + 1)

    # Left panel: loss curves.
    loss_ax.plot(xs, history["train_loss"], label="Train")
    loss_ax.plot(xs, history["val_loss"], label="Val")
    loss_ax.set_title("Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.legend()
    loss_ax.grid(True)

    # Right panel: accuracy curves, pinned to [0, 1].
    acc_ax.plot(xs, history["train_acc"], label="Train")
    acc_ax.plot(xs, history["val_acc"], label="Val")
    acc_ax.set_title("Accuracy")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylim(0, 1)
    acc_ax.legend()
    acc_ax.grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# CLI
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
if __name__ == "__main__":
    # CLI entry point — see the module docstring for example invocations.
    p = argparse.ArgumentParser(description="Train Siamese Neural Network")
    # Data locations (synthetic shapes are used when train_dir is omitted).
    p.add_argument("--train_dir", default=None, help="Path to training data root")
    p.add_argument("--val_dir", default=None, help="Path to validation data root")
    p.add_argument("--output_dir", default="checkpoints")
    # Input and optimisation hyper-parameters.
    p.add_argument("--img_size", type=int, default=224)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--embedding_dim", type=int, default=128)
    p.add_argument("--pairs_per_epoch", type=int, default=5000)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--patience", type=int, default=5,
                   help="Early stopping patience (epochs)")

    args = p.parse_args()
    train(args)
|
||||
Loading…
Reference in New Issue