"""
Siamese Neural Network for Image Similarity Matching
======================================================

Robust to: cropping, resolution differences, noise, rotation, color jitter.

Architecture:
- Shared CNN encoder (ResNet-18 backbone, pretrained)
- L2-normalized embedding head
- Contrastive Loss / Binary Cross-Entropy with cosine similarity
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


# ─────────────────────────────────────────────────────────────
# Embedding Network (shared weights for both branches)
# ─────────────────────────────────────────────────────────────
class EmbeddingNet(nn.Module):
    """
    ResNet-18 backbone with a custom projection head.

    Outputs a 128-dim L2-normalised embedding, suitable for cosine or
    Euclidean similarity comparisons.

    Args:
        embedding_dim: Size of the output embedding vector.
        pretrained: If True, initialise the backbone with ImageNet weights.
    """

    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
        super().__init__()

        # Load the (optionally pretrained) ResNet-18 backbone.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)

        # Read the feature width off the backbone instead of hard-coding 512,
        # so swapping in a wider/narrower ResNet variant keeps working.
        feat_dim = backbone.fc.in_features

        # Drop the final FC layer; keep everything up to and including avgpool.
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])

        # Projection head: feat_dim → 256 → embedding_dim
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, embedding_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images (B, 3, H, W) to L2-normalised embeddings (B, D)."""
        feats = self.feature_extractor(x)          # (B, feat_dim, 1, 1)
        feats = torch.flatten(feats, start_dim=1)  # (B, feat_dim)
        emb = self.projector(feats)                # (B, embedding_dim)
        # Unit-norm embeddings make cosine similarity a plain dot product.
        return F.normalize(emb, p=2, dim=1)


# ─────────────────────────────────────────────────────────────
# Siamese Network
# ─────────────────────────────────────────────────────────────
class SiameseNet(nn.Module):
    """
    Two-branch network with a shared encoder.

    Given a pair of images, returns:
      • similarity score ∈ [0, 1] (1 = same, 0 = different)
      • both L2-normalised embeddings
    """

    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
        super().__init__()
        # A single encoder instance → weights are shared between both branches.
        self.encoder = EmbeddingNet(embedding_dim, pretrained)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor):
        """Encode both images and score their similarity."""
        left = self.get_embedding(img1)
        right = self.get_embedding(img2)
        # Cosine similarity lives in [−1, 1]; shift/scale it into [0, 1].
        similarity = F.cosine_similarity(left, right, dim=1)
        score = 0.5 * (similarity + 1.0)  # (B,)
        return score, left, right

    def get_embedding(self, img: torch.Tensor) -> torch.Tensor:
        """Embed a single batch of images (useful at inference time)."""
        return self.encoder(img)


# ─────────────────────────────────────────────────────────────
# Loss Functions
# ─────────────────────────────────────────────────────────────
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss (LeCun 2005).

    label = 1 → positive pair: pull the embeddings together.
    label = 0 → negative pair: push them at least `margin` apart.
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Mean contrastive loss over the batch of embedding pairs."""
        distance = F.pairwise_distance(emb1, emb2)  # Euclidean, shape (B,)
        # Positive pairs pay the squared distance itself...
        pull = label * distance.pow(2)
        # ...negative pairs pay only while they sit inside the margin.
        push = (1 - label) * F.relu(self.margin - distance).pow(2)
        return torch.mean(pull + push)


class BCECosineLoss(nn.Module):
    """
    Binary Cross-Entropy on the cosine-similarity score.

    The score is expected to already lie in [0, 1] (as produced by rescaled
    cosine similarity), so plain BCE applies directly. Scores are clamped
    slightly away from the endpoints: identical (or opposite) embeddings
    yield exactly 1.0 (or 0.0), where log(0) would otherwise saturate
    nn.BCELoss at its internal clamp and destabilise gradients.
    """

    # Keep scores inside (EPS, 1 - EPS) so log() stays finite and well-behaved.
    _EPS = 1e-7

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()

    def forward(self, score, label):
        """
        Args:
            score: similarity scores in [0, 1], shape (B,).
            label: 1 for same-pair, 0 for different-pair (cast to float).

        Returns:
            Scalar mean BCE loss.
        """
        score = score.clamp(self._EPS, 1.0 - self._EPS)
        return self.bce(score, label.float())