listing-radar/inference.py

"""
Inference Module
================
Load a trained Siamese Network and compare images.
Usage:
    python inference.py --img1 a.jpg --img2 b.jpg --checkpoint checkpoints/best_model.pth
    python inference.py --img1 logo.png --gallery_dir /images/   # find best match
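
Python API (paths are illustrative):
    from inference import SiamesePredictor
    predictor = SiamesePredictor("checkpoints/best_model.pth")
    print(predictor.compare("a.jpg", "b.jpg"))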
"""
import argparse
from pathlib import Path
from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from model import SiameseNet
from dataset import get_val_transform, IMAGENET_MEAN, IMAGENET_STD
# ─────────────────────────────────────────────────────────────
# Predictor class
# ─────────────────────────────────────────────────────────────
class SiamesePredictor:
"""
High-level inference wrapper.
"""
    def __init__(
        self,
        checkpoint_path: Optional[str] = None,
        embedding_dim: int = 128,
        device: Optional[str] = None,
        threshold: float = 0.6,
    ):
        # Prefer CUDA, then Apple Silicon (MPS), then CPU.
        self.device = torch.device(device or (
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        ))
        self.threshold = threshold
        self.transform = get_val_transform(224)
        # Build model; weights come from the checkpoint, so no pretrained download.
        self.model = SiameseNet(embedding_dim=embedding_dim, pretrained=False)
        self.model.to(self.device)
        if checkpoint_path and Path(checkpoint_path).exists():
            ckpt = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(ckpt["model_state"])
            print(f"✅ Loaded checkpoint from {checkpoint_path}")
            print(f"   Trained for {ckpt.get('epoch', '?')} epochs | "
                  f"val_acc = {ckpt.get('val_acc', 0):.3f}")
        else:
            print("⚠️ No checkpoint — running with random weights (demo only).")
        # Inference mode: disables dropout and freezes batch-norm statistics.
        self.model.eval()

    def _load(self, path_or_pil) -> torch.Tensor:
        """Open a path or PIL image, apply the eval transform, add a batch dim."""
        if isinstance(path_or_pil, (str, Path)):
            img = Image.open(path_or_pil).convert("RGB")
        else:
            img = path_or_pil.convert("RGB")
        return self.transform(img).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def compare(self, img1, img2) -> dict:
"""
Compare two images.
Returns dict with score, match, and euclidean distance.
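
        Example (paths are illustrative; assumes a constructed predictor):
            result = predictor.compare("a.jpg", "b.jpg")
            if result["match"]:
                print("match:", result["similarity_score"])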
"""
        t1 = self._load(img1)
        t2 = self._load(img2)
        score, emb1, emb2 = self.model(t1, t2)
        s = score.item()
        dist = F.pairwise_distance(emb1, emb2).item()
        return {
            "similarity_score": round(s, 4),
            "euclidean_dist": round(dist, 4),
            "match": s >= self.threshold,
            # Confidence: distance of the score from the 0.5 decision midpoint.
            "confidence": "HIGH" if abs(s - 0.5) > 0.3 else
                          "MED" if abs(s - 0.5) > 0.15 else "LOW",
        }

    @torch.no_grad()
    def embed(self, img) -> np.ndarray:
        """Return the embedding vector for an image, shape (embedding_dim,)."""
        t = self._load(img)
        return self.model.get_embedding(t).cpu().numpy()[0]

    @torch.no_grad()
    def find_best_match(
        self, query_img, gallery: List, top_k: int = 5
    ) -> List[Tuple[float, int]]:
"""
Find top-k most similar images from a gallery list.
gallery: list of image paths or PIL images.
Returns: list of (score, index) sorted descending.
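
        Example (directory is illustrative; assumes a constructed predictor):
            gallery = list(Path("images").glob("*.jpg"))
            for score, idx in predictor.find_best_match("logo.png", gallery):
                print(f"{score:.3f}  {gallery[idx]}")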
"""
        q_emb = torch.from_numpy(self.embed(query_img)).to(self.device)
        scores = []
        for idx, img in enumerate(gallery):
            g_emb = torch.from_numpy(self.embed(img)).to(self.device)
            cos = F.cosine_similarity(q_emb.unsqueeze(0), g_emb.unsqueeze(0)).item()
            # Map cosine similarity from [-1, 1] to [0, 1].
            s = (cos + 1.0) / 2.0
            scores.append((s, idx))
        # Sort by score only, so ties don't fall back to comparing indices.
        scores.sort(key=lambda t: t[0], reverse=True)
        return scores[:top_k]
# ─────────────────────────────────────────────────────────────
# Visualisation helper
# ─────────────────────────────────────────────────────────────
def visualise_comparison(img1_path, img2_path, result: dict, save_path="comparison.png"):
    """Render the two images side by side with the match verdict in between."""
    img1 = Image.open(img1_path).convert("RGB")
    img2 = Image.open(img2_path).convert("RGB")
    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    fig.patch.set_facecolor("#0f0f1a")
    for ax in axes:
        ax.set_facecolor("#0f0f1a")
        ax.axis("off")
    axes[0].imshow(img1)
    axes[0].set_title("Image 1", color="white", fontsize=13, pad=10)
    axes[2].imshow(img2)
    axes[2].set_title("Image 2", color="white", fontsize=13, pad=10)
    # Middle panel: verdict card.
    color = "#00e676" if result["match"] else "#ff1744"
    symbol = "✓ MATCH" if result["match"] else "✗ NO MATCH"
    axes[1].set_xlim(0, 1)
    axes[1].set_ylim(0, 1)
    axes[1].add_patch(mpatches.FancyBboxPatch(
        (0.1, 0.2), 0.8, 0.6, boxstyle="round,pad=0.05",
        linewidth=3, edgecolor=color, facecolor="#1a1a2e"
    ))
    axes[1].text(0.5, 0.75, symbol, ha="center", va="center", fontsize=20,
                 fontweight="bold", color=color)
    axes[1].text(0.5, 0.55, f"Score: {result['similarity_score']:.3f}",
                 ha="center", va="center", fontsize=14, color="white")
    axes[1].text(0.5, 0.40, f"Dist: {result['euclidean_dist']:.3f}",
                 ha="center", va="center", fontsize=11, color="#aaaaaa")
    axes[1].text(0.5, 0.28, f"Conf: {result['confidence']}",
                 ha="center", va="center", fontsize=11, color="#aaaaaa")
    plt.tight_layout(pad=2)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"📸 Visualisation saved → {save_path}")
# ─────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Siamese Network Inference")
    p.add_argument("--img1", required=True, help="Path to image 1 (or the query image)")
    p.add_argument("--img2", default=None, help="Path to image 2 (comparison mode)")
    p.add_argument("--gallery_dir", default=None, help="Directory to search (gallery mode)")
    p.add_argument("--checkpoint", default="checkpoints/best_model.pth")
    p.add_argument("--threshold", type=float, default=0.6, help="Match threshold")
    p.add_argument("--top_k", type=int, default=5)
    p.add_argument("--output", default="comparison.png")
    args = p.parse_args()

    predictor = SiamesePredictor(
        checkpoint_path=args.checkpoint,
        threshold=args.threshold,
    )

    if args.img2:
        # ── Compare two images ────────────────────────────────
        result = predictor.compare(args.img1, args.img2)
        print("\n" + "─" * 45)
        print(f" Similarity score : {result['similarity_score']}")
        print(f" Euclidean dist   : {result['euclidean_dist']}")
        print(f" Match            : {'YES ✓' if result['match'] else 'NO ✗'}")
        print(f" Confidence       : {result['confidence']}")
        print("─" * 45)
        visualise_comparison(args.img1, args.img2, result, args.output)
    elif args.gallery_dir:
        # ── Gallery search ────────────────────────────────────
        exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
        gallery = [f for f in Path(args.gallery_dir).iterdir() if f.suffix.lower() in exts]
        matches = predictor.find_best_match(args.img1, gallery, top_k=args.top_k)
        print(f"\nTop-{args.top_k} matches for {args.img1}:")
        for rank, (score, idx) in enumerate(matches, 1):
            verdict = "✓ MATCH" if score >= args.threshold else "✗ DIFF "
            print(f" {rank}. {verdict} score={score:.4f} {gallery[idx]}")
    else:
        p.error("provide --img2 or --gallery_dir")