"""
Training Script
===============

Usage:
    python train.py                                            # synthetic demo
    python train.py --train_dir data/train --val_dir data/val
    python train.py --train_dir data/train --epochs 30 --batch_size 64
"""
|
|
|
|
import argparse
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.optim as optim
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
from tqdm import tqdm
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
|
|
from model import SiameseNet, ContrastiveLoss, BCECosineLoss
|
|
from dataset import build_dataloaders
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
# ─────────────────────────────────────────────────────────────
|
|
# Metrics
|
|
# ─────────────────────────────────────────────────────────────
|
|
def compute_accuracy(scores: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5) -> float:
    """Return the fraction of pair predictions that match the labels.

    Args:
        scores: Per-pair similarity scores (expected in [0, 1]).
        labels: Ground-truth 0/1 labels, same length as ``scores``.
        threshold: Scores ``>= threshold`` are predicted positive.

    Returns:
        Accuracy in [0, 1]; 0.0 for an empty batch (previously this
        raised ZeroDivisionError).
    """
    if len(labels) == 0:  # guard: empty batch would divide by zero below
        return 0.0
    preds = (scores >= threshold).float()
    correct = (preds == labels).sum().item()
    return correct / len(labels)
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────
|
|
# One epoch
|
|
# ─────────────────────────────────────────────────────────────
|
|
def run_epoch(model, loader, criterion, optimizer, device, train: bool):
    """Run one full pass over *loader* and return aggregate metrics.

    Args:
        model: Siamese model returning ``(scores, emb1, emb2)``.
        loader: DataLoader yielding ``(img1, img2, labels)`` batches.
        criterion: Loss applied to ``(scores, labels)``.
        optimizer: Used only when ``train`` is True (may be None otherwise).
        device: Target device for the batch tensors.
        train: Toggles train/eval mode, gradients, and the optimizer step.

    Returns:
        ``(mean_loss, mean_accuracy)``, each weighted by batch size.
    """
    if train:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    # Gradients only when training; inference runs under no_grad.
    ctx = torch.enable_grad() if train else torch.no_grad()
    with ctx:
        for img1, img2, labels in tqdm(loader, desc="train" if train else "val ", leave=False):
            img1 = img1.to(device)
            img2 = img2.to(device)
            labels = labels.to(device)

            scores, emb1, emb2 = model(img1, img2)

            # Use BCECosineLoss by default; swap to ContrastiveLoss if you prefer
            loss = criterion(scores, labels)

            if train:
                optimizer.zero_grad()
                loss.backward()
                # Clip gradients to guard against instability/explosions.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            batch = img1.size(0)
            loss_sum += loss.item() * batch
            acc_sum += compute_accuracy(scores.detach().cpu(), labels.cpu()) * batch
            seen += batch

    return loss_sum / seen, acc_sum / seen
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────
|
|
# Training loop
|
|
# ─────────────────────────────────────────────────────────────
|
|
def train(args):
    """Full training loop: data, model, optimization, checkpointing, plots.

    Args:
        args: Parsed CLI namespace (see the argparse definition at the
            bottom of this file). ``train_dir``/``val_dir`` fall back to the
            ``DATASET``/``VAL_DATASET`` environment variables when not given.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )
    print(f"🖥️ Device: {device}")

    # BUGFIX: environment variables are now a *fallback* only. Previously
    # they unconditionally overwrote --train_dir/--val_dir, silently
    # ignoring the CLI flags advertised in the module docstring.
    args.train_dir = args.train_dir or os.getenv("DATASET")
    args.val_dir = args.val_dir or os.getenv("VAL_DATASET")

    # ── Data ──────────────────────────────────────────────────
    train_loader, val_loader = build_dataloaders(
        train_root=args.train_dir,
        val_root=args.val_dir,
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pairs_per_epoch=args.pairs_per_epoch,
    )
    print(f"📦 Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

    # ── Model ─────────────────────────────────────────────────
    model = SiameseNet(embedding_dim=args.embedding_dim, pretrained=True).to(device)
    criterion = BCECosineLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Cosine decay over the whole run, down to a small floor LR.
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)

    # ── Output dir ────────────────────────────────────────────
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_loss = float("inf")
    patience_counter = 0

    print("\n🚀 Training started\n")
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, train=True)
        va_loss, va_acc = run_epoch(model, val_loader, criterion, None, device, train=False)

        scheduler.step()

        history["train_loss"].append(tr_loss)
        history["val_loss"].append(va_loss)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(va_acc)

        elapsed = time.time() - t0
        print(
            f"Epoch {epoch:3d}/{args.epochs} │ "
            f"Train loss={tr_loss:.4f} acc={tr_acc:.3f} │ "
            f"Val loss={va_loss:.4f} acc={va_acc:.3f} │ "
            f"LR={scheduler.get_last_lr()[0]:.2e} [{elapsed:.1f}s]"
        )

        # ── Checkpoint ────────────────────────────────────────
        # Save whenever validation loss improves; otherwise count toward
        # early stopping.
        if va_loss < best_val_loss:
            best_val_loss = va_loss
            patience_counter = 0
            ckpt_path = out / "best_model.pth"
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optim_state": optimizer.state_dict(),
                "val_loss": va_loss,
                "val_acc": va_acc,
                "args": vars(args),
            }, ckpt_path)
            print(f"   ✅ Best model saved → {ckpt_path}")
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"\n⏹ Early stopping after {epoch} epochs (patience={args.patience})")
                break

        # Refresh pairs each epoch (random re-pairing)
        if hasattr(train_loader.dataset, "on_epoch_end"):
            train_loader.dataset.on_epoch_end()

    # ── Plot training curves ───────────────────────────────────
    _plot_history(history, out / "training_curves.png")
    print(f"\n📊 Training curves saved → {out / 'training_curves.png'}")
    print(f"🏆 Best val loss: {best_val_loss:.4f}")
|
|
|
|
|
|
def _plot_history(history, save_path):
    """Render loss and accuracy curves from *history* and save a PNG.

    Args:
        history: Dict with ``train_loss``/``val_loss``/``train_acc``/``val_acc``
            lists, one entry per completed epoch.
        save_path: Destination image path.
    """
    xs = range(1, len(history["train_loss"]) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: loss curves.
    loss_ax.plot(xs, history["train_loss"], label="Train")
    loss_ax.plot(xs, history["val_loss"], label="Val")
    loss_ax.set_title("Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.legend()
    loss_ax.grid(True)

    # Right panel: accuracy curves, pinned to the [0, 1] range.
    acc_ax.plot(xs, history["train_acc"], label="Train")
    acc_ax.plot(xs, history["val_acc"], label="Val")
    acc_ax.set_title("Accuracy")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylim(0, 1)
    acc_ax.legend()
    acc_ax.grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────
|
|
# CLI
|
|
# ─────────────────────────────────────────────────────────────
|
|
if __name__ == "__main__":
    # Command-line interface: all options have defaults so a bare
    # `python train.py` runs out of the box.
    parser = argparse.ArgumentParser(description="Train Siamese Neural Network")
    parser.add_argument("--train_dir", default=None, help="Path to training data root")
    parser.add_argument("--val_dir", default=None, help="Path to validation data root")
    parser.add_argument("--output_dir", default="checkpoints")
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--pairs_per_epoch", type=int, default=5000)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--patience", type=int, default=5,
                        help="Early stopping patience (epochs)")

    train(parser.parse_args())
|