"""Build a CSV of image pairs labelled by whether their file contents match.

Images under a root folder are grouped by the SHA-1 of their bytes. Every
pair of files inside one group becomes a positive row (label 1); negative
rows (label 0) are sampled at random from files belonging to two different
groups. Paths in the CSV are written relative to the root folder using
forward slashes.
"""

import argparse
import csv
import hashlib
import random
from itertools import combinations
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

# Number of bytes read per iteration while hashing, so that large files are
# never loaded into memory in one piece.
_HASH_CHUNK_SIZE = 1024 * 1024


def sha1_of_file(path: Path) -> str:
    """Return the hexadecimal SHA-1 digest of the bytes stored at *path*.

    The digest is only used as a content fingerprint for grouping identical
    files, so the hash is requested with ``usedforsecurity=False``; this keeps
    the script usable on builds that restrict SHA-1 for security purposes.
    """
    hasher = hashlib.sha1(usedforsecurity=False)
    with path.open("rb") as f:
        while chunk := f.read(_HASH_CHUNK_SIZE):
            hasher.update(chunk)
    return hasher.hexdigest()


def collect_images(root: Path) -> list[Path]:
    """Return every file below *root* whose suffix is in ``IMAGE_EXTS``.

    The suffix comparison is case-insensitive. The result is sorted because
    directory traversal order depends on the filesystem, and everything
    downstream (grouping, sampling, shuffling) consumes this order; without
    sorting, a fixed random seed would not give reproducible output.
    """
    return sorted(
        p
        for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )


def build_positive_pairs(
    groups: dict[str, list[Path]],
) -> list[tuple[Path, Path, int]]:
    """Return one ``(a, b, 1)`` row per unordered pair inside each group.

    Groups holding fewer than two paths contribute nothing, since
    ``combinations`` yields no pairs for them.
    """
    return [
        (a, b, 1)
        for paths in groups.values()
        for a, b in combinations(paths, 2)
    ]


def build_negative_pairs(
    groups: dict[str, list[Path]],
    target_count: int,
    rng: random.Random,
) -> list[tuple[Path, Path, int]]:
    """Sample up to *target_count* distinct ``(p1, p2, 0)`` rows.

    Each row takes one path from each of two different groups. Sampling is by
    rejection: a pair that was already produced (in either order) is skipped.
    The number of draws is capped at ``20 * target_count`` so the loop always
    terminates, which means fewer than *target_count* rows may be returned
    when few distinct cross-group pairs exist.

    Returns an empty list when there are fewer than two groups or when
    *target_count* is not positive.
    """
    keys = list(groups.keys())
    if len(keys) < 2:
        return []

    pairs: list[tuple[Path, Path, int]] = []
    seen: set[tuple[str, ...]] = set()
    max_attempts = target_count * 20 if target_count > 0 else 0
    attempts = 0

    while len(pairs) < target_count and attempts < max_attempts:
        attempts += 1
        k1, k2 = rng.sample(keys, 2)
        p1 = rng.choice(groups[k1])
        p2 = rng.choice(groups[k2])
        # Order-independent key: (a, b) and (b, a) count as the same pair.
        key = tuple(sorted((str(p1), str(p2))))
        if key in seen:
            continue
        seen.add(key)
        pairs.append((p1, p2, 0))

    return pairs


def write_csv(rows: list[tuple[Path, Path, int]], output: Path, base: Path) -> None:
    """Write *rows* to *output* as CSV with a header line.

    Both paths of every row are written relative to *base*. Missing parent
    directories of *output* are created.

    Raises:
        ValueError: if a path in *rows* is not located under *base*
            (raised by ``Path.relative_to``).
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    with output.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_path_1", "image_path_2", "label"])
        for p1, p2, label in rows:
            # as_posix() converts only real path separators, so a backslash
            # that is part of a POSIX filename is left intact.
            writer.writerow([
                p1.relative_to(base).as_posix(),
                p2.relative_to(base).as_posix(),
                label,
            ])


def main() -> None:
    """Parse the command line, build the pair dataset and write the CSV.

    Exits with an error message when no image is found under the root folder.
    The number of negative pairs requested is the number of positive pairs
    multiplied by ``--neg_ratio`` (truncated to an integer).
    """
    parser = argparse.ArgumentParser(
        description="Create pair dataset CSV: label 1 for same image, else 0."
    )
    parser.add_argument("--root", default="data_images", help="Root folder containing images")
    parser.add_argument("--output", default="pairs_dataset.csv", help="Output CSV path")
    parser.add_argument(
        "--neg_ratio",
        type=float,
        default=1.0,
        help="Number of negative pairs per positive pair (default: 1.0)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    root = Path(args.root).resolve()
    output = Path(args.output).resolve()
    rng = random.Random(args.seed)

    images = collect_images(root)
    if not images:
        raise SystemExit(f"No images found under: {root}")

    # Files with identical bytes share a digest and land in the same group.
    hash_groups: dict[str, list[Path]] = {}
    for img in images:
        file_hash = sha1_of_file(img)
        hash_groups.setdefault(file_hash, []).append(img)

    pos_pairs = build_positive_pairs(hash_groups)
    neg_target = int(len(pos_pairs) * args.neg_ratio)
    neg_pairs = build_negative_pairs(hash_groups, neg_target, rng)

    all_pairs = pos_pairs + neg_pairs
    rng.shuffle(all_pairs)

    write_csv(all_pairs, output, root)

    print(f"Images found: {len(images)}")
    print(f"Unique images by hash: {len(hash_groups)}")
    print(f"Positive pairs (label=1): {len(pos_pairs)}")
    print(f"Negative pairs (label=0): {len(neg_pairs)}")
    print(f"Total pairs written: {len(all_pairs)}")
    print(f"CSV: {output}")


if __name__ == "__main__":
    main()