From 0694f9162ab72e4aa79a976f7bee162ad5080820 Mon Sep 17 00:00:00 2001
From: "bahawal.baloch"
Date: Thu, 7 May 2026 17:45:48 +0500
Subject: [PATCH] Add image processing and downloading scripts for DINOv2 model

- Introduced `dino_image_matching.py` for generating image embeddings and
  calculating cosine similarity.
- Added `download_images.py` to download images from URLs specified in an
  Excel file, with support for parallel downloads and retries.
- Created `generate_pairs_csv.py` to generate a CSV file of image pairs for
  training, including positive and negative pairs based on image hashes.
---
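Notes: example invocations, assuming the scripts run from the repository
root and images are downloaded into data_images/ so the pair-generation
and matching scripts can find them:

    python download_images.py --input listing_data.xlsx --output-dir data_images
    python generate_pairs_csv.py --root data_images --output pairs_dataset.csv
    python dino_image_matching.py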
 dino_image_matching.py |  54 +++++++++++++
 download_images.py     | 163 +++++++++++++++++++++++++++++++++++++++++++
 generate_pairs_csv.py  | 132 ++++++++++++++++++++++++++++++++++++
 3 files changed, 349 insertions(+)
 create mode 100644 dino_image_matching.py
 create mode 100644 download_images.py
 create mode 100644 generate_pairs_csv.py

diff --git a/dino_image_matching.py b/dino_image_matching.py
new file mode 100644
index 0000000..0005903
--- /dev/null
+++ b/dino_image_matching.py
@@ -0,0 +1,54 @@
+import warnings
+
+import torch
+from PIL import Image
+from torchvision import transforms
+import torch.nn.functional as F
+
+# Optional dependency warnings from DINOv2 internals are non-critical.
+warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning)
+
+# Load model
+model = torch.hub.load(
+    'facebookresearch/dinov2',
+    'dinov2_vitb14'
+)
+
+model.eval()
+
+# Device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Image preprocessing. 518 x 518 is a multiple of the ViT-B/14 patch size;
+# ImageNet normalization matches the statistics DINOv2 was trained with.
+transform = transforms.Compose([
+    transforms.Resize((518, 518)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+def get_embedding(image_path):
+    # Load image
+    image = Image.open(image_path).convert("RGB")
+
+    # Transform
+    tensor = transform(image).unsqueeze(0).to(device)
+
+    # Generate embedding
+    with torch.no_grad():
+        embedding = model(tensor)
+
+    # Normalize embedding (important for cosine similarity)
+    embedding = F.normalize(embedding, p=2, dim=1)
+
+    return embedding.cpu()
+
+# Example
+emb1 = get_embedding(r"data_images\B0B39FFJHF\03.jpg")
+emb2 = get_embedding(r"data_images\B09RWY127Q\03.jpg")
+
+# Cosine similarity of the L2-normalized embeddings (1.0 means identical direction)
+similarity = F.cosine_similarity(emb1, emb2, dim=1)
+
+print("Cosine similarity:", similarity.item())
diff --git a/download_images.py b/download_images.py
new file mode 100644
index 0000000..1739bf2
--- /dev/null
+++ b/download_images.py
@@ -0,0 +1,163 @@
+import argparse
+import re
+import urllib.error
+import urllib.parse
+import urllib.request
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+import pandas as pd
+
+
+def sanitize_name(value: str, fallback: str) -> str:
+    cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", str(value)).strip()
+    cleaned = re.sub(r"\s+", " ", cleaned)
+    return cleaned[:120] if cleaned else fallback
+
+
+def split_urls(raw_value: object) -> list[str]:
+    if pd.isna(raw_value):
+        return []
+    text = str(raw_value).strip()
+    if not text:
+        return []
+    return [u.strip() for u in text.split(";") if u.strip()]
+
+
+def file_extension_from_url(url: str) -> str:
+    parsed = urllib.parse.urlparse(url)
+    path = parsed.path or ""
+    suffix = Path(path).suffix.lower()
+    if suffix in {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}:
+        return suffix
+    return ".jpg"
+
+
+def download_file(url: str, destination: Path, timeout: int, retries: int) -> tuple[bool, str]:
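+    """Fetch one URL and write the bytes to `destination`.
+
+    Retries up to `retries` extra times and returns (ok, message) instead of
+    raising, so the thread pool can tally successes and failures.
+    Note: some CDNs reject plain urllib requests that lack a User-Agent header.
+    """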
Success: {success_count}, Failed: {fail_count}") + print(f"Images saved under: {output_dir.resolve()}") + + +if __name__ == "__main__": + main() diff --git a/generate_pairs_csv.py b/generate_pairs_csv.py new file mode 100644 index 0000000..d8af65b --- /dev/null +++ b/generate_pairs_csv.py @@ -0,0 +1,126 @@ +import argparse +import csv +import hashlib +import random +from itertools import combinations +from pathlib import Path + + +IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"} + + +def sha1_of_file(path: Path) -> str: + hasher = hashlib.sha1() + with path.open("rb") as f: + while True: + chunk = f.read(1024 * 1024) + if not chunk: + break + hasher.update(chunk) + return hasher.hexdigest() + + +def collect_images(root: Path) -> list[Path]: + return [ + p for p in root.rglob("*") + if p.is_file() and p.suffix.lower() in IMAGE_EXTS + ] + + +def build_positive_pairs(groups: dict[str, list[Path]]) -> list[tuple[Path, Path, int]]: + pairs: list[tuple[Path, Path, int]] = [] + for paths in groups.values(): + if len(paths) < 2: + continue + for a, b in combinations(paths, 2): + pairs.append((a, b, 1)) + return pairs + + +def build_negative_pairs( + groups: dict[str, list[Path]], + target_count: int, + rng: random.Random, +) -> list[tuple[Path, Path, int]]: + keys = list(groups.keys()) + if len(keys) < 2: + return [] + + pairs: list[tuple[Path, Path, int]] = [] + seen = set() + max_attempts = target_count * 20 if target_count > 0 else 0 + attempts = 0 + + while len(pairs) < target_count and attempts < max_attempts: + attempts += 1 + k1, k2 = rng.sample(keys, 2) + p1 = rng.choice(groups[k1]) + p2 = rng.choice(groups[k2]) + + key = tuple(sorted((str(p1), str(p2)))) + if key in seen: + continue + seen.add(key) + pairs.append((p1, p2, 0)) + return pairs + + +def write_csv(rows: list[tuple[Path, Path, int]], output: Path, base: Path) -> None: + output.parent.mkdir(parents=True, exist_ok=True) + with output.open("w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(["image_path_1", "image_path_2", "label"]) + for p1, p2, label in rows: + writer.writerow([ + str(p1.relative_to(base)).replace("\\", "/"), + str(p2.relative_to(base)).replace("\\", "/"), + label, + ]) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Create pair dataset CSV: label 1 for same image, else 0." 
+import argparse
+import csv
+import hashlib
+import random
+from itertools import combinations
+from pathlib import Path
+
+
+IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
+
+
+def sha1_of_file(path: Path) -> str:
+    hasher = hashlib.sha1()
+    with path.open("rb") as f:
+        while True:
+            chunk = f.read(1024 * 1024)
+            if not chunk:
+                break
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
+def collect_images(root: Path) -> list[Path]:
+    return [
+        p for p in root.rglob("*")
+        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
+    ]
+
+
+def build_positive_pairs(groups: dict[str, list[Path]]) -> list[tuple[Path, Path, int]]:
+    pairs: list[tuple[Path, Path, int]] = []
+    for paths in groups.values():
+        if len(paths) < 2:
+            continue
+        for a, b in combinations(paths, 2):
+            pairs.append((a, b, 1))
+    return pairs
+
+
+def build_negative_pairs(
+    groups: dict[str, list[Path]],
+    target_count: int,
+    rng: random.Random,
+) -> list[tuple[Path, Path, int]]:
+    keys = list(groups.keys())
+    if len(keys) < 2:
+        return []
+
+    pairs: list[tuple[Path, Path, int]] = []
+    seen = set()
+    max_attempts = target_count * 20 if target_count > 0 else 0
+    attempts = 0
+
+    while len(pairs) < target_count and attempts < max_attempts:
+        attempts += 1
+        k1, k2 = rng.sample(keys, 2)
+        p1 = rng.choice(groups[k1])
+        p2 = rng.choice(groups[k2])
+
+        key = tuple(sorted((str(p1), str(p2))))
+        if key in seen:
+            continue
+        seen.add(key)
+        pairs.append((p1, p2, 0))
+    return pairs
+
+
+def write_csv(rows: list[tuple[Path, Path, int]], output: Path, base: Path) -> None:
+    output.parent.mkdir(parents=True, exist_ok=True)
+    with output.open("w", newline="", encoding="utf-8") as f:
+        writer = csv.writer(f)
+        writer.writerow(["image_path_1", "image_path_2", "label"])
+        for p1, p2, label in rows:
+            writer.writerow([
+                str(p1.relative_to(base)).replace("\\", "/"),
+                str(p2.relative_to(base)).replace("\\", "/"),
+                label,
+            ])
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Create a pair-dataset CSV: label 1 for byte-identical images, else 0."
+    )
+    parser.add_argument("--root", default="data_images", help="Root folder containing images")
+    parser.add_argument("--output", default="pairs_dataset.csv", help="Output CSV path")
+    parser.add_argument(
+        "--neg_ratio",
+        type=float,
+        default=1.0,
+        help="Number of negative pairs per positive pair (default: 1.0)",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed")
+    args = parser.parse_args()
+
+    root = Path(args.root).resolve()
+    output = Path(args.output).resolve()
+    rng = random.Random(args.seed)
+
+    images = collect_images(root)
+    if not images:
+        raise SystemExit(f"No images found under: {root}")
+
+    hash_groups: dict[str, list[Path]] = {}
+    for img in images:
+        file_hash = sha1_of_file(img)
+        hash_groups.setdefault(file_hash, []).append(img)
+
+    pos_pairs = build_positive_pairs(hash_groups)
+    neg_target = int(len(pos_pairs) * args.neg_ratio)
+    neg_pairs = build_negative_pairs(hash_groups, neg_target, rng)
+
+    all_pairs = pos_pairs + neg_pairs
+    rng.shuffle(all_pairs)
+    write_csv(all_pairs, output, root)
+
+    print(f"Images found: {len(images)}")
+    print(f"Unique images by hash: {len(hash_groups)}")
+    print(f"Positive pairs (label=1): {len(pos_pairs)}")
+    print(f"Negative pairs (label=0): {len(neg_pairs)}")
+    print(f"Total pairs written: {len(all_pairs)}")
+    print(f"CSV: {output}")
+
+
+if __name__ == "__main__":
+    main()