""" Inference Module ================ Load a trained Siamese Network and compare images. Usage: python inference.py --img1 a.jpg --img2 b.jpg --checkpoint checkpoints/best_model.pth python inference.py --query logo.png --gallery_dir /images/ # find best match """ import argparse from pathlib import Path from typing import List, Tuple, Optional import torch import torch.nn.functional as F from torchvision import transforms from PIL import Image import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patches as mpatches import numpy as np from model import SiameseNet from dataset import get_val_transform, IMAGENET_MEAN, IMAGENET_STD # ───────────────────────────────────────────────────────────── # Predictor class # ───────────────────────────────────────────────────────────── class SiamesePredictor: """ High-level inference wrapper. """ def __init__( self, checkpoint_path: Optional[str] = None, embedding_dim: int = 128, device: Optional[str] = None, threshold: float = 0.6, ): self.device = torch.device(device or ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" )) self.threshold = threshold self.transform = get_val_transform(224) # Build model self.model = SiameseNet(embedding_dim=embedding_dim, pretrained=False) self.model.to(self.device) if checkpoint_path and Path(checkpoint_path).exists(): ckpt = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(ckpt["model_state"]) print(f"✅ Loaded checkpoint from {checkpoint_path}") print(f" Trained for {ckpt.get('epoch', '?')} epochs | " f"val_acc = {ckpt.get('val_acc', 0):.3f}") else: print("⚠️ No checkpoint — running with random weights (demo only).") self.model.eval() def _load(self, path_or_pil) -> torch.Tensor: if isinstance(path_or_pil, (str, Path)): img = Image.open(path_or_pil).convert("RGB") else: img = path_or_pil.convert("RGB") return self.transform(img).unsqueeze(0).to(self.device) @torch.no_grad() def compare(self, img1, img2) -> dict: """ Compare two images. Returns dict with score, match, and euclidean distance. """ t1 = self._load(img1) t2 = self._load(img2) score, emb1, emb2 = self.model(t1, t2) s = score.item() dist = F.pairwise_distance(emb1, emb2).item() return { "similarity_score": round(s, 4), "euclidean_dist": round(dist, 4), "match": s >= self.threshold, "confidence": "HIGH" if abs(s - 0.5) > 0.3 else "MED" if abs(s - 0.5) > 0.15 else "LOW", } @torch.no_grad() def embed(self, img) -> np.ndarray: """Get embedding vector for an image.""" t = self._load(img) return self.model.get_embedding(t).cpu().numpy()[0] @torch.no_grad() def find_best_match( self, query_img, gallery: List, top_k: int = 5 ) -> List[Tuple[float, int]]: """ Find top-k most similar images from a gallery list. gallery: list of image paths or PIL images. Returns: list of (score, index) sorted descending. 
""" q_emb = torch.from_numpy(self.embed(query_img)).to(self.device) scores = [] for idx, img in enumerate(gallery): g_emb = torch.from_numpy(self.embed(img)).to(self.device) cos = F.cosine_similarity(q_emb.unsqueeze(0), g_emb.unsqueeze(0)).item() s = (cos + 1.0) / 2.0 scores.append((s, idx)) scores.sort(reverse=True) return scores[:top_k] # ───────────────────────────────────────────────────────────── # Visualisation helper # ───────────────────────────────────────────────────────────── def visualise_comparison(img1_path, img2_path, result: dict, save_path="comparison.png"): img1 = Image.open(img1_path).convert("RGB") img2 = Image.open(img2_path).convert("RGB") fig, axes = plt.subplots(1, 3, figsize=(14, 5)) fig.patch.set_facecolor("#0f0f1a") for ax in axes: ax.set_facecolor("#0f0f1a") ax.axis("off") axes[0].imshow(img1) axes[0].set_title("Image 1", color="white", fontsize=13, pad=10) axes[2].imshow(img2) axes[2].set_title("Image 2", color="white", fontsize=13, pad=10) # Middle panel: result color = "#00e676" if result["match"] else "#ff1744" symbol = "✓ MATCH" if result["match"] else "✗ NO MATCH" axes[1].set_xlim(0, 1); axes[1].set_ylim(0, 1) axes[1].add_patch(mpatches.FancyBboxPatch( (0.1, 0.2), 0.8, 0.6, boxstyle="round,pad=0.05", linewidth=3, edgecolor=color, facecolor="#1a1a2e" )) axes[1].text(0.5, 0.75, symbol, ha="center", va="center", fontsize=20, fontweight="bold", color=color) axes[1].text(0.5, 0.55, f"Score: {result['similarity_score']:.3f}", ha="center", va="center", fontsize=14, color="white") axes[1].text(0.5, 0.40, f"Dist: {result['euclidean_dist']:.3f}", ha="center", va="center", fontsize=11, color="#aaaaaa") axes[1].text(0.5, 0.28, f"Conf: {result['confidence']}", ha="center", va="center", fontsize=11, color="#aaaaaa") plt.tight_layout(pad=2) plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f"📸 Visualisation saved → {save_path}") # ───────────────────────────────────────────────────────────── # CLI # ───────────────────────────────────────────────────────────── if __name__ == "__main__": p = argparse.ArgumentParser(description="Siamese Network Inference") p.add_argument("--img1", required=True, help="Path to image 1") p.add_argument("--img2", default=None, help="Path to image 2 (comparison mode)") p.add_argument("--gallery_dir", default=None, help="Directory to search (gallery mode)") p.add_argument("--checkpoint", default="checkpoints/best_model.pth") p.add_argument("--threshold", type=float, default=0.6, help="Match threshold") p.add_argument("--top_k", type=int, default=5) p.add_argument("--output", default="comparison.png") args = p.parse_args() predictor = SiamesePredictor( checkpoint_path=args.checkpoint, threshold=args.threshold, ) if args.img2: # ── Compare two images ──────────────────────────────── result = predictor.compare(args.img1, args.img2) print("\n" + "═" * 45) print(f" Similarity score : {result['similarity_score']}") print(f" Euclidean dist : {result['euclidean_dist']}") print(f" Match : {'YES ✓' if result['match'] else 'NO ✗'}") print(f" Confidence : {result['confidence']}") print("═" * 45) visualise_comparison(args.img1, args.img2, result, args.output) elif args.gallery_dir: # ── Gallery search ──────────────────────────────────── exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"} gallery = [p for p in Path(args.gallery_dir).iterdir() if p.suffix.lower() in exts] matches = predictor.find_best_match(args.img1, gallery, top_k=args.top_k) print(f"\nTop-{args.top_k} matches for {args.img1}:") for rank, (score, idx) in 
# ─────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────

if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Siamese Network Inference")
    p.add_argument("--img1", required=True, help="Path to image 1 (or query image)")
    p.add_argument("--img2", default=None, help="Path to image 2 (comparison mode)")
    p.add_argument("--gallery_dir", default=None, help="Directory to search (gallery mode)")
    p.add_argument("--checkpoint", default="checkpoints/best_model.pth")
    p.add_argument("--threshold", type=float, default=0.6, help="Match threshold")
    p.add_argument("--top_k", type=int, default=5)
    p.add_argument("--output", default="comparison.png")
    args = p.parse_args()

    predictor = SiamesePredictor(
        checkpoint_path=args.checkpoint,
        threshold=args.threshold,
    )

    if args.img2:
        # ── Compare two images ────────────────────────────────
        result = predictor.compare(args.img1, args.img2)
        print("\n" + "═" * 45)
        print(f"  Similarity score : {result['similarity_score']}")
        print(f"  Euclidean dist   : {result['euclidean_dist']}")
        print(f"  Match            : {'YES ✓' if result['match'] else 'NO ✗'}")
        print(f"  Confidence       : {result['confidence']}")
        print("═" * 45)
        visualise_comparison(args.img1, args.img2, result, args.output)

    elif args.gallery_dir:
        # ── Gallery search ────────────────────────────────────
        exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
        gallery = sorted(
            fp for fp in Path(args.gallery_dir).iterdir()
            if fp.suffix.lower() in exts
        )
        matches = predictor.find_best_match(args.img1, gallery, top_k=args.top_k)
        print(f"\nTop-{args.top_k} matches for {args.img1}:")
        for rank, (score, idx) in enumerate(matches, 1):
            verdict = "✓ MATCH" if score >= args.threshold else "✗ diff "
            print(f"  {rank}. {verdict}  score={score:.4f}  {gallery[idx]}")
    else:
        print("Provide --img2 or --gallery_dir")
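
# ─────────────────────────────────────────────────────────────
# Programmatic usage (sketch)
# ─────────────────────────────────────────────────────────────
# The CLI above covers the common cases; the same flow from Python, with
# placeholder paths:
#
#   from inference import SiamesePredictor, visualise_comparison
#   predictor = SiamesePredictor("checkpoints/best_model.pth", threshold=0.6)
#   result = predictor.compare("a.jpg", "b.jpg")
#   if result["match"]:
#       visualise_comparison("a.jpg", "b.jpg", result)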