# Listing header (from the file viewer): 127 lines, 3.7 KiB, Python
import argparse
|
|
import csv
|
|
import hashlib
|
|
import random
|
|
from itertools import combinations
|
|
from pathlib import Path
|
|
|
|
|
|
# File suffixes (lower-case, including the leading dot) that count as images.
# Candidate suffixes are lower-cased before being checked against this set.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
|
|
|
|
|
|
def sha1_of_file(path: Path) -> str:
    """Return the hexadecimal SHA-1 digest of the file at *path*.

    The file is opened in binary mode and consumed in 1 MiB blocks, so the
    whole content is never held in memory at once.
    """
    digest = hashlib.sha1()
    block_size = 1024 * 1024
    with path.open("rb") as stream:
        # iter() with a sentinel stops on the first empty read (end of file).
        for block in iter(lambda: stream.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
|
|
def collect_images(root: Path) -> list[Path]:
    """Recursively collect the image files under *root*, in sorted order.

    A path is kept when it is a regular file whose suffix, compared
    case-insensitively, is a member of ``IMAGE_EXTS``.

    The result is sorted because directory traversal order depends on the
    filesystem; a stable order lets callers that consume the list with a
    seeded random generator get the same outcome on every machine and run.

    Returns:
        The matching paths, sorted; an empty list when nothing matches.
    """
    return sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )
|
|
|
|
|
|
def build_positive_pairs(groups: dict[str, list[Path]]) -> list[tuple[Path, Path, int]]:
    """Pair up every two paths that share a group and label each pair 1.

    Groups with fewer than two members contribute nothing. Output follows
    the iteration order of *groups*, and within one group the order in
    which ``itertools.combinations`` emits its 2-element combinations.
    """
    return [
        (first, second, 1)
        for members in groups.values()
        for first, second in combinations(members, 2)
    ]
|
|
|
|
|
|
def build_negative_pairs(
    groups: dict[str, list[Path]],
    target_count: int,
    rng: random.Random,
) -> list[tuple[Path, Path, int]]:
    """Randomly sample up to *target_count* distinct cross-group pairs, labelled 0.

    Each attempt draws two different group keys with *rng* and then one path
    from each of those groups. A pair is treated as a duplicate regardless of
    the order of its two members. Sampling stops once *target_count* pairs
    are collected or after ``target_count * 20`` attempts, so fewer pairs
    than requested may be returned.

    Returns:
        A list of ``(path_1, path_2, 0)`` tuples; empty when there are fewer
        than two groups or when *target_count* is not positive.
    """
    group_ids = list(groups)
    if len(group_ids) < 2:
        return []

    negatives: list[tuple[Path, Path, int]] = []
    used = set()
    # Retry budget: bounds the loop when few distinct pairs are available.
    remaining = target_count * 20 if target_count > 0 else 0

    while remaining > 0 and len(negatives) < target_count:
        remaining -= 1
        first_id, second_id = rng.sample(group_ids, 2)
        left = rng.choice(groups[first_id])
        right = rng.choice(groups[second_id])

        # Order-independent key so (a, b) and (b, a) count as the same pair.
        a, b = str(left), str(right)
        pair_key = (a, b) if a <= b else (b, a)
        if pair_key not in used:
            used.add(pair_key)
            negatives.append((left, right, 0))
    return negatives
|
|
|
|
|
|
def write_csv(rows: list[tuple[Path, Path, int]], output: Path, base: Path) -> None:
    """Write labelled path pairs to a UTF-8 CSV file at *output*.

    Missing parent directories of *output* are created. The file starts with
    the header ``image_path_1,image_path_2,label``; each following row holds
    the two paths expressed relative to *base*, plus the label.

    Paths are rendered with ``as_posix()`` so Windows separators become
    forward slashes, while a literal backslash inside a POSIX file name is
    left untouched (a blanket ``replace("\\", "/")`` would corrupt it).

    Raises:
        ValueError: if a path in *rows* is not located under *base*.
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with output.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_path_1", "image_path_2", "label"])
        writer.writerows(
            (p1.relative_to(base).as_posix(), p2.relative_to(base).as_posix(), label)
            for p1, p2, label in rows
        )
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: hash images, build labelled pairs, write the CSV.

    Images under ``--root`` are grouped by the SHA-1 of their content. Files
    sharing a digest form positive pairs (label 1); negative pairs (label 0)
    are sampled across different digests, aiming for ``--neg_ratio`` negatives
    per positive. The combined rows are shuffled with a generator seeded from
    ``--seed``, written to ``--output``, and a short summary is printed.

    Raises:
        SystemExit: when no image files are found under the root folder.
    """
    parser = argparse.ArgumentParser(
        description="Create pair dataset CSV: label 1 for same image, else 0."
    )
    parser.add_argument("--root", default="data_images", help="Root folder containing images")
    parser.add_argument("--output", default="pairs_dataset.csv", help="Output CSV path")
    parser.add_argument(
        "--neg_ratio",
        type=float,
        default=1.0,
        help="Number of negative pairs per positive pair (default: 1.0)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    rng = random.Random(args.seed)
    root = Path(args.root).resolve()
    output = Path(args.output).resolve()

    images = collect_images(root)
    if not images:
        raise SystemExit(f"No images found under: {root}")

    # Bucket files by content digest; identical files land in the same bucket.
    hash_groups: dict[str, list[Path]] = {}
    for img in images:
        hash_groups.setdefault(sha1_of_file(img), []).append(img)

    pos_pairs = build_positive_pairs(hash_groups)
    # The number of negatives requested scales with the number of positives.
    neg_pairs = build_negative_pairs(
        hash_groups,
        int(len(pos_pairs) * args.neg_ratio),
        rng,
    )

    all_pairs = [*pos_pairs, *neg_pairs]
    rng.shuffle(all_pairs)
    write_csv(all_pairs, output, root)

    summary = (
        f"Images found: {len(images)}",
        f"Unique images by hash: {len(hash_groups)}",
        f"Positive pairs (label=1): {len(pos_pairs)}",
        f"Negative pairs (label=0): {len(neg_pairs)}",
        f"Total pairs written: {len(all_pairs)}",
        f"CSV: {output}",
    )
    for line in summary:
        print(line)
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|