listing-radar/model_export/model.py

114 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
Siamese Neural Network for Image Similarity Matching
======================================================
Robust to: cropping, resolution differences, noise, rotation, color jitter.
Architecture:
- Shared CNN encoder (ResNet-18 backbone, pretrained)
- L2-normalized embedding head
- Contrastive Loss / Binary Cross-Entropy with cosine similarity
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# ─────────────────────────────────────────────────────────────
# Embedding Network (shared weights for both branches)
# ─────────────────────────────────────────────────────────────
class EmbeddingNet(nn.Module):
    """Shared-weight image encoder: ResNet-18 trunk + projection MLP.

    Maps a (B, 3, H, W) image batch to an ``embedding_dim``-dimensional
    vector lying on the unit hypersphere (L2 norm == 1).
    """

    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
        super().__init__()
        # Optionally start from ImageNet-pretrained weights.
        resnet_weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        trunk = models.resnet18(weights=resnet_weights)
        # Drop the final classification FC layer; keep conv stages + avgpool.
        self.feature_extractor = nn.Sequential(*list(trunk.children())[:-1])
        # Projection head: 512 -> 256 -> embedding_dim.
        self.projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, embedding_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into L2-normalised embeddings."""
        pooled = self.feature_extractor(x)            # (B, 512, 1, 1)
        flat = torch.flatten(pooled, start_dim=1)     # (B, 512)
        projected = self.projector(flat)              # (B, embedding_dim)
        return F.normalize(projected, p=2, dim=1)     # unit-norm rows
# ─────────────────────────────────────────────────────────────
# Siamese Network
# ─────────────────────────────────────────────────────────────
class SiameseNet(nn.Module):
    """Twin-branch network sharing a single EmbeddingNet encoder.

    ``forward`` returns:
        * similarity score in [0, 1] (1 = same item, 0 = different)
        * both L2-normalised embeddings
    """

    def __init__(self, embedding_dim: int = 128, pretrained: bool = True):
        super().__init__()
        self.encoder = EmbeddingNet(embedding_dim, pretrained)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor):
        """Embed both images and score them via rescaled cosine similarity."""
        left = self.encoder(img1)
        right = self.encoder(img2)
        # Cosine similarity lies in [-1, 1]; shift and scale into [0, 1].
        similarity = (F.cosine_similarity(left, right, dim=1) + 1.0) / 2.0
        return similarity, left, right

    def get_embedding(self, img: torch.Tensor) -> torch.Tensor:
        """Encode a single image batch (inference-time helper)."""
        return self.encoder(img)
# ─────────────────────────────────────────────────────────────
# Loss Functions
# ─────────────────────────────────────────────────────────────
class ContrastiveLoss(nn.Module):
    """Classic contrastive loss (Hadsell, Chopra & LeCun).

    label == 1 : positive pair — penalise squared Euclidean distance.
    label == 0 : negative pair — penalise only when closer than ``margin``.
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Return the mean contrastive loss over the batch."""
        distance = F.pairwise_distance(emb1, emb2)
        # Positive pairs are pulled together; negatives pushed past the margin.
        pull = label * distance.pow(2)
        push = (1 - label) * torch.clamp(self.margin - distance, min=0).pow(2)
        return (pull + push).mean()
class BCECosineLoss(nn.Module):
    """Binary cross-entropy applied to the [0, 1] cosine-similarity score."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()

    def forward(self, score, label):
        """Return BCE between predicted score and the 0/1 pair label."""
        target = label.float()  # BCELoss requires a float target
        return self.bce(score, target)