""" Training Script =============== Usage: python train.py # synthetic demo python train.py --train_dir data/train --val_dir data/val python train.py --train_dir data/train --epochs 30 --batch_size 64 """ import argparse import os import time from pathlib import Path import torch import torch.optim as optim from torch.optim.lr_scheduler import CosineAnnealingLR from tqdm import tqdm import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from model import SiameseNet, ContrastiveLoss, BCECosineLoss from dataset import build_dataloaders from dotenv import load_dotenv load_dotenv() # ───────────────────────────────────────────────────────────── # Metrics # ───────────────────────────────────────────────────────────── def compute_accuracy(scores: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5): preds = (scores >= threshold).float() correct = (preds == labels).sum().item() return correct / len(labels) # ───────────────────────────────────────────────────────────── # One epoch # ───────────────────────────────────────────────────────────── def run_epoch(model, loader, criterion, optimizer, device, train: bool): model.train() if train else model.eval() total_loss, total_acc, n = 0.0, 0.0, 0 ctx = torch.enable_grad() if train else torch.no_grad() with ctx: for img1, img2, labels in tqdm(loader, desc="train" if train else "val ", leave=False): img1, img2, labels = img1.to(device), img2.to(device), labels.to(device) scores, emb1, emb2 = model(img1, img2) # Use BCECosineLoss by default; swap to ContrastiveLoss if you prefer loss = criterion(scores, labels) if train: optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() acc = compute_accuracy(scores.detach().cpu(), labels.cpu()) bs = img1.size(0) total_loss += loss.item() * bs total_acc += acc * bs n += bs return total_loss / n, total_acc / n # ───────────────────────────────────────────────────────────── # Training loop # ───────────────────────────────────────────────────────────── def train(args): device = torch.device( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) print(f"🖥️ Device: {device}") args.train_dir=os.getenv("DATASET") args.val_dir=os.getenv("VAL_DATASET") # ── Data ────────────────────────────────────────────────── train_loader, val_loader = build_dataloaders( train_root = args.train_dir, val_root = args.val_dir, img_size = args.img_size, batch_size = args.batch_size, num_workers = args.num_workers, pairs_per_epoch = args.pairs_per_epoch, ) print(f"📦 Train batches: {len(train_loader)} | Val batches: {len(val_loader)}") # ── Model ───────────────────────────────────────────────── model = SiameseNet(embedding_dim=args.embedding_dim, pretrained=True).to(device) criterion = BCECosineLoss() optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4) scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) # ── Output dir ──────────────────────────────────────────── out = Path(args.output_dir) out.mkdir(parents=True, exist_ok=True) history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []} best_val_loss = float("inf") patience_counter = 0 print("\n🚀 Training started\n") for epoch in range(1, args.epochs + 1): t0 = time.time() tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, train=True) va_loss, va_acc = run_epoch(model, val_loader, criterion, None, device, train=False) scheduler.step() 
history["train_loss"].append(tr_loss) history["val_loss"].append(va_loss) history["train_acc"].append(tr_acc) history["val_acc"].append(va_acc) elapsed = time.time() - t0 print( f"Epoch {epoch:3d}/{args.epochs} │ " f"Train loss={tr_loss:.4f} acc={tr_acc:.3f} │ " f"Val loss={va_loss:.4f} acc={va_acc:.3f} │ " f"LR={scheduler.get_last_lr()[0]:.2e} [{elapsed:.1f}s]" ) # ── Checkpoint ──────────────────────────────────────── if va_loss < best_val_loss: best_val_loss = va_loss patience_counter = 0 ckpt_path = out / "best_model.pth" torch.save({ "epoch": epoch, "model_state": model.state_dict(), "optim_state": optimizer.state_dict(), "val_loss": va_loss, "val_acc": va_acc, "args": vars(args), }, ckpt_path) print(f" ✅ Best model saved → {ckpt_path}") else: patience_counter += 1 if patience_counter >= args.patience: print(f"\n⏹ Early stopping after {epoch} epochs (patience={args.patience})") break # Refresh pairs each epoch (random re-pairing) if hasattr(train_loader.dataset, "on_epoch_end"): train_loader.dataset.on_epoch_end() # ── Plot training curves ─────────────────────────────────── _plot_history(history, out / "training_curves.png") print(f"\n📊 Training curves saved → {out / 'training_curves.png'}") print(f"🏆 Best val loss: {best_val_loss:.4f}") def _plot_history(history, save_path): fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) epochs = range(1, len(history["train_loss"]) + 1) ax1.plot(epochs, history["train_loss"], label="Train") ax1.plot(epochs, history["val_loss"], label="Val") ax1.set_title("Loss"); ax1.set_xlabel("Epoch"); ax1.legend(); ax1.grid(True) ax2.plot(epochs, history["train_acc"], label="Train") ax2.plot(epochs, history["val_acc"], label="Val") ax2.set_title("Accuracy"); ax2.set_xlabel("Epoch"); ax2.set_ylim(0, 1) ax2.legend(); ax2.grid(True) plt.tight_layout() plt.savefig(save_path, dpi=150) plt.close() # ───────────────────────────────────────────────────────────── # CLI # ───────────────────────────────────────────────────────────── if __name__ == "__main__": p = argparse.ArgumentParser(description="Train Siamese Neural Network") p.add_argument("--train_dir", default=None, help="Path to training data root") p.add_argument("--val_dir", default=None, help="Path to validation data root") p.add_argument("--output_dir", default="checkpoints") p.add_argument("--img_size", type=int, default=224) p.add_argument("--batch_size", type=int, default=32) p.add_argument("--epochs", type=int, default=20) p.add_argument("--lr", type=float, default=1e-4) p.add_argument("--embedding_dim", type=int, default=128) p.add_argument("--pairs_per_epoch", type=int, default=5000) p.add_argument("--num_workers", type=int, default=4) p.add_argument("--patience", type=int, default=5, help="Early stopping patience (epochs)") args = p.parse_args() train(args)