"""
Inference Module
================

Load a trained Siamese Network and compare images.

Usage:
    python inference.py --img1 a.jpg --img2 b.jpg --checkpoint checkpoints/best_model.pth
    python inference.py --img1 logo.png --gallery_dir /images/   # find best match
"""
# Standard library
import argparse
from pathlib import Path
from typing import List, Optional, Tuple

# Third-party
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib

matplotlib.use("Agg")  # headless backend: render straight to files, no display
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# Project-local
from model import SiameseNet
from dataset import get_val_transform, IMAGENET_MEAN, IMAGENET_STD


# ─────────────────────────────────────────────────────────────
# Predictor class
# ─────────────────────────────────────────────────────────────
class SiamesePredictor:
    """
    High-level inference wrapper around a trained SiameseNet.

    Args:
        checkpoint_path: Path to a ``.pth`` checkpoint containing a
            ``model_state`` dict (plus optional ``epoch`` / ``val_acc``).
            If missing or None, the model runs with random weights.
        embedding_dim: Dimensionality of the embedding head; must match
            the value used at training time.
        device: Explicit device string; auto-detected when None.
        threshold: Similarity score at/above which two images count as
            a match.
    """

    def __init__(
        self,
        checkpoint_path: Optional[str] = None,
        embedding_dim: int = 128,
        device: Optional[str] = None,
        threshold: float = 0.6,
    ):
        # Prefer CUDA, then Apple MPS, then CPU when no device is given.
        self.device = torch.device(device or (
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        ))
        self.threshold = threshold
        self.transform = get_val_transform(224)

        # Build the model; weights come from the checkpoint, not ImageNet.
        self.model = SiameseNet(embedding_dim=embedding_dim, pretrained=False)
        self.model.to(self.device)

        if checkpoint_path and Path(checkpoint_path).exists():
            # NOTE(security): torch.load unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            ckpt = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(ckpt["model_state"])
            print(f"✅ Loaded checkpoint from {checkpoint_path}")
            print(f" Trained for {ckpt.get('epoch', '?')} epochs | "
                  f"val_acc = {ckpt.get('val_acc', 0):.3f}")
        else:
            print("⚠️ No checkpoint — running with random weights (demo only).")

        self.model.eval()

    def _load(self, path_or_pil) -> torch.Tensor:
        """Load an image (path or PIL.Image), preprocess, and move to device.

        Returns a (1, C, H, W) tensor ready for the model.
        """
        if isinstance(path_or_pil, (str, Path)):
            img = Image.open(path_or_pil).convert("RGB")
        else:
            img = path_or_pil.convert("RGB")
        return self.transform(img).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def compare(self, img1, img2) -> dict:
        """
        Compare two images.

        Args:
            img1, img2: File paths or PIL images.

        Returns:
            dict with ``similarity_score``, ``euclidean_dist``, ``match``
            (score >= threshold), and a coarse ``confidence`` label based
            on how far the score is from the 0.5 decision midpoint.
        """
        t1 = self._load(img1)
        t2 = self._load(img2)
        score, emb1, emb2 = self.model(t1, t2)
        s = score.item()
        dist = F.pairwise_distance(emb1, emb2).item()
        return {
            "similarity_score": round(s, 4),
            "euclidean_dist": round(dist, 4),
            "match": s >= self.threshold,
            # Distance from 0.5: > 0.3 → HIGH, > 0.15 → MED, else LOW.
            "confidence": "HIGH" if abs(s - 0.5) > 0.3 else
                          "MED" if abs(s - 0.5) > 0.15 else "LOW",
        }

    @torch.no_grad()
    def embed(self, img) -> np.ndarray:
        """Get the embedding vector for an image as a 1-D numpy array."""
        t = self._load(img)
        return self.model.get_embedding(t).cpu().numpy()[0]

    @torch.no_grad()
    def find_best_match(
        self, query_img, gallery: List, top_k: int = 5
    ) -> List[Tuple[float, int]]:
        """
        Find top-k most similar images from a gallery list.

        Args:
            query_img: Query image (path or PIL image).
            gallery: List of image paths or PIL images.
            top_k: Number of results to return.

        Returns:
            List of (score, index) sorted descending, where score is
            cosine similarity rescaled from [-1, 1] to [0, 1].
        """
        # Embed directly as tensors — the previous implementation went
        # through self.embed(), forcing a GPU→numpy→GPU round trip per
        # gallery image for no benefit. Math is unchanged.
        q_emb = self.model.get_embedding(self._load(query_img))
        scores = []
        for idx, img in enumerate(gallery):
            g_emb = self.model.get_embedding(self._load(img))
            cos = F.cosine_similarity(q_emb, g_emb).item()
            scores.append(((cos + 1.0) / 2.0, idx))  # rescale to [0, 1]
        # Tuple sort: descending score; equal scores fall back to higher index.
        scores.sort(reverse=True)
        return scores[:top_k]
# ─────────────────────────────────────────────────────────────
# Visualisation helper
# ─────────────────────────────────────────────────────────────
def visualise_comparison(img1_path, img2_path, result: dict, save_path="comparison.png"):
    """Render a side-by-side comparison figure and save it to disk.

    The outer panels show the two input images; the centre panel shows
    the verdict box with score, distance and confidence taken from
    ``result`` (as produced by ``SiamesePredictor.compare``).
    """
    left = Image.open(img1_path).convert("RGB")
    right = Image.open(img2_path).convert("RGB")

    fig, (ax_left, ax_mid, ax_right) = plt.subplots(1, 3, figsize=(14, 5))
    fig.patch.set_facecolor("#0f0f1a")

    # Dark theme, no axes on any panel.
    for ax in (ax_left, ax_mid, ax_right):
        ax.set_facecolor("#0f0f1a")
        ax.axis("off")

    ax_left.imshow(left)
    ax_left.set_title("Image 1", color="white", fontsize=13, pad=10)

    ax_right.imshow(right)
    ax_right.set_title("Image 2", color="white", fontsize=13, pad=10)

    # Middle panel: verdict box, coloured green/red by match outcome.
    matched = result["match"]
    color = "#00e676" if matched else "#ff1744"
    symbol = "✓ MATCH" if matched else "✗ NO MATCH"
    ax_mid.set_xlim(0, 1)
    ax_mid.set_ylim(0, 1)
    ax_mid.add_patch(mpatches.FancyBboxPatch(
        (0.1, 0.2), 0.8, 0.6, boxstyle="round,pad=0.05",
        linewidth=3, edgecolor=color, facecolor="#1a1a2e"
    ))
    ax_mid.text(0.5, 0.75, symbol, ha="center", va="center", fontsize=20,
                fontweight="bold", color=color)
    ax_mid.text(0.5, 0.55, f"Score: {result['similarity_score']:.3f}",
                ha="center", va="center", fontsize=14, color="white")
    ax_mid.text(0.5, 0.40, f"Dist: {result['euclidean_dist']:.3f}",
                ha="center", va="center", fontsize=11, color="#aaaaaa")
    ax_mid.text(0.5, 0.28, f"Conf: {result['confidence']}",
                ha="center", va="center", fontsize=11, color="#aaaaaa")

    plt.tight_layout(pad=2)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"📸 Visualisation saved → {save_path}")
# ─────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Siamese Network Inference")
    parser.add_argument("--img1", required=True, help="Path to image 1")
    parser.add_argument("--img2", default=None, help="Path to image 2 (comparison mode)")
    parser.add_argument("--gallery_dir", default=None, help="Directory to search (gallery mode)")
    parser.add_argument("--checkpoint", default="checkpoints/best_model.pth")
    parser.add_argument("--threshold", type=float, default=0.6, help="Match threshold")
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--output", default="comparison.png")
    args = parser.parse_args()

    predictor = SiamesePredictor(
        checkpoint_path=args.checkpoint,
        threshold=args.threshold,
    )

    if args.img2:
        # ── Compare two images ────────────────────────────────
        result = predictor.compare(args.img1, args.img2)
        banner = "═" * 45
        print("\n" + banner)
        print(f" Similarity score : {result['similarity_score']}")
        print(f" Euclidean dist : {result['euclidean_dist']}")
        print(f" Match : {'YES ✓' if result['match'] else 'NO ✗'}")
        print(f" Confidence : {result['confidence']}")
        print(banner)
        visualise_comparison(args.img1, args.img2, result, args.output)

    elif args.gallery_dir:
        # ── Gallery search ────────────────────────────────────
        exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
        gallery = [f for f in Path(args.gallery_dir).iterdir() if f.suffix.lower() in exts]
        matches = predictor.find_best_match(args.img1, gallery, top_k=args.top_k)
        print(f"\nTop-{args.top_k} matches for {args.img1}:")
        for rank, (score, idx) in enumerate(matches, 1):
            verdict = "✓ MATCH" if score >= args.threshold else " diff "
            print(f" {rank}. {verdict} score={score:.4f} {gallery[idx]}")
    else:
        print("Provide --img2 or --gallery_dir")