# listing-radar/model_export/train.py
"""
Training Script
===============
Usage:
python train.py # synthetic demo
python train.py --train_dir data/train --val_dir data/val
python train.py --train_dir data/train --epochs 30 --batch_size 64
"""
import argparse
import os
import time
from pathlib import Path
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from model import SiameseNet, ContrastiveLoss, BCECosineLoss
from dataset import build_dataloaders
from dotenv import load_dotenv
load_dotenv()
# ─────────────────────────────────────────────────────────────
# Metrics
# ─────────────────────────────────────────────────────────────
def compute_accuracy(scores: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5) -> float:
    """Fraction of predictions matching *labels* after thresholding *scores*.

    Args:
        scores: 1-D tensor of similarity scores (assumed in [0, 1] — TODO confirm
            against the model's output activation).
        labels: 1-D tensor of ground-truth labels (0.0 / 1.0), same length as scores.
        threshold: Scores >= threshold are predicted positive.

    Returns:
        Accuracy in [0, 1]; 0.0 for empty input (previously raised ZeroDivisionError).
    """
    if len(labels) == 0:  # guard: empty batch would divide by zero below
        return 0.0
    preds = (scores >= threshold).float()
    correct = (preds == labels).sum().item()
    return correct / len(labels)
# ─────────────────────────────────────────────────────────────
# One epoch
# ─────────────────────────────────────────────────────────────
def run_epoch(model, loader, criterion, optimizer, device, train: bool):
    """Run one full pass over *loader* and return (mean_loss, mean_accuracy).

    Args:
        model: Siamese network returning (scores, emb1, emb2) for an image pair.
        loader: DataLoader yielding (img1, img2, labels) batches.
        criterion: Loss taking (scores, labels) — e.g. BCECosineLoss.
        optimizer: Optimizer; may be None when train=False (it is never touched
            on the eval path).
        device: torch.device to move batches onto.
        train: If True, runs backprop with gradient clipping; otherwise runs
            under torch.no_grad().

    Returns:
        Tuple of per-sample-weighted mean loss and mean accuracy over the epoch;
        (0.0, 0.0) for an empty loader (previously raised ZeroDivisionError).
    """
    if train:
        model.train()
    else:
        model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0
    # Gradients only needed on the training path.
    ctx = torch.enable_grad() if train else torch.no_grad()
    with ctx:
        for img1, img2, labels in tqdm(loader, desc="train" if train else "val ", leave=False):
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            scores, emb1, emb2 = model(img1, img2)
            # Use BCECosineLoss by default; swap to ContrastiveLoss if you prefer
            loss = criterion(scores, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                # Clip to stabilize training against exploding gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            acc = compute_accuracy(scores.detach().cpu(), labels.cpu())
            bs = img1.size(0)
            # Weight running sums by batch size so a short final batch
            # doesn't skew the epoch averages.
            total_loss += loss.item() * bs
            total_acc += acc * bs
            n += bs
    if n == 0:  # guard: empty loader would divide by zero
        return 0.0, 0.0
    return total_loss / n, total_acc / n
# ─────────────────────────────────────────────────────────────
# Training loop
# ─────────────────────────────────────────────────────────────
def train(args):
    """Full training loop: data, model, optimization, checkpointing, plots.

    Args:
        args: argparse.Namespace from the CLI (train_dir, val_dir, output_dir,
            img_size, batch_size, epochs, lr, embedding_dim, pairs_per_epoch,
            num_workers, patience).

    Side effects:
        Writes best_model.pth and training_curves.png under args.output_dir;
        prints progress to stdout.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )
    print(f"🖥️ Device: {device}")
    # BUGFIX: the env vars previously clobbered the CLI flags unconditionally,
    # so --train_dir/--val_dir were silently ignored. Use the .env values only
    # as a fallback when the flags were not supplied.
    args.train_dir = args.train_dir or os.getenv("DATASET")
    args.val_dir = args.val_dir or os.getenv("VAL_DATASET")
    # ── Data ──────────────────────────────────────────────────
    train_loader, val_loader = build_dataloaders(
        train_root = args.train_dir,
        val_root = args.val_dir,
        img_size = args.img_size,
        batch_size = args.batch_size,
        num_workers = args.num_workers,
        pairs_per_epoch = args.pairs_per_epoch,
    )
    print(f"📦 Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")
    # ── Model ─────────────────────────────────────────────────
    model = SiameseNet(embedding_dim=args.embedding_dim, pretrained=True).to(device)
    criterion = BCECosineLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Cosine decay over the whole run, down to a small floor LR.
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
    # ── Output dir ────────────────────────────────────────────
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_loss = float("inf")
    patience_counter = 0
    print("\n🚀 Training started\n")
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, train=True)
        # No optimizer needed on the eval pass.
        va_loss, va_acc = run_epoch(model, val_loader, criterion, None, device, train=False)
        scheduler.step()
        history["train_loss"].append(tr_loss)
        history["val_loss"].append(va_loss)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(va_acc)
        elapsed = time.time() - t0
        # BUGFIX: the adjacent f-strings had no separators, producing e.g.
        # "Epoch   3/20Train loss=0.5123 acc=0.731Val loss=..." — add " | ".
        print(
            f"Epoch {epoch:3d}/{args.epochs} | "
            f"Train loss={tr_loss:.4f} acc={tr_acc:.3f} | "
            f"Val loss={va_loss:.4f} acc={va_acc:.3f} | "
            f"LR={scheduler.get_last_lr()[0]:.2e} [{elapsed:.1f}s]"
        )
        # ── Checkpoint ────────────────────────────────────────
        if va_loss < best_val_loss:
            best_val_loss = va_loss
            patience_counter = 0
            ckpt_path = out / "best_model.pth"
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optim_state": optimizer.state_dict(),
                "val_loss": va_loss,
                "val_acc": va_acc,
                "args": vars(args),
            }, ckpt_path)
            print(f"  ✅ Best model saved → {ckpt_path}")
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"\n⏹ Early stopping after {epoch} epochs (patience={args.patience})")
                break
        # Refresh pairs each epoch (random re-pairing)
        if hasattr(train_loader.dataset, "on_epoch_end"):
            train_loader.dataset.on_epoch_end()
    # ── Plot training curves ───────────────────────────────────
    _plot_history(history, out / "training_curves.png")
    print(f"\n📊 Training curves saved → {out / 'training_curves.png'}")
    print(f"🏆 Best val loss: {best_val_loss:.4f}")
def _plot_history(history, save_path):
    """Save side-by-side loss and accuracy curves (train vs. val) as a PNG."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    xs = range(1, len(history["train_loss"]) + 1)
    # Left panel: loss curves.
    loss_ax.plot(xs, history["train_loss"], label="Train")
    loss_ax.plot(xs, history["val_loss"], label="Val")
    loss_ax.set_title("Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.legend()
    loss_ax.grid(True)
    # Right panel: accuracy curves, pinned to the [0, 1] range.
    acc_ax.plot(xs, history["train_acc"], label="Train")
    acc_ax.plot(xs, history["val_acc"], label="Val")
    acc_ax.set_title("Accuracy")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylim(0, 1)
    acc_ax.legend()
    acc_ax.grid(True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
# ─────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Command-line entry point: parse hyperparameters and launch training.
    parser = argparse.ArgumentParser(description="Train Siamese Neural Network")
    # Data locations (fall back to .env values inside train() when omitted).
    parser.add_argument("--train_dir", default=None, help="Path to training data root")
    parser.add_argument("--val_dir", default=None, help="Path to validation data root")
    parser.add_argument("--output_dir", default="checkpoints")
    # Model / optimization hyperparameters.
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--pairs_per_epoch", type=int, default=5000)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--patience", type=int, default=5,
                        help="Early stopping patience (epochs)")
    train(parser.parse_args())