Add image processing and downloading scripts for the DINOv2 model
- Introduced `dino_image_matching.py` for generating image embeddings and calculating cosine similarity.
- Added `download_images.py` to download images from URLs specified in an Excel file, with support for parallel downloads and retries.
- Created `generate_pairs_csv.py` to generate a CSV file of image pairs for training, including positive and negative pairs based on image hashes.
parent
c66d3b1029
commit
0694f9162a
dino_image_matching.py
@@ -0,0 +1,54 @@
import warnings

import torch
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F

# Optional dependency warnings from DINOv2 internals are non-critical.
warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning)

# Load model
model = torch.hub.load(
    'facebookresearch/dinov2',
    'dinov2_vitb14'
)
model.eval()

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((518, 518)),  # DINOv2 recommended size
    transforms.ToTensor(),
])


def get_embedding(image_path):
    # Load image
    image = Image.open(image_path).convert("RGB")

    # Transform
    tensor = transform(image).unsqueeze(0).to(device)

    # Generate embedding
    with torch.no_grad():
        embedding = model(tensor)

    # Normalize embedding (important for cosine similarity)
    embedding = F.normalize(embedding, p=2, dim=1)

    return embedding.cpu()


# Example
emb1 = get_embedding(r"data_images\B0B39FFJHF\03.jpg")
emb2 = get_embedding(r"data_images\B09RWY127Q\03.jpg")

# Cosine similarity (dot product of the two L2-normalized embeddings)
similarity = F.cosine_similarity(emb1, emb2)

print("Cosine similarity:", similarity.item())
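One preprocessing note: the transform above feeds raw [0, 1] tensors to the model, while DINOv2's reference evaluation transform also standardizes with ImageNet statistics. A minimal sketch of the extended pipeline, assuming those standard constants (the `transform_normalized` name is illustrative, not part of the commit):

from torchvision import transforms

# Sketch only: adds the ImageNet mean/std standardization that DINOv2's
# reference preprocessing applies after ToTensor().
transform_normalized = transforms.Compose([
    transforms.Resize((518, 518)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

Since get_embedding() L2-normalizes its output, cosine similarity and Euclidean distance are monotonically related here (||a - b||^2 = 2 - 2*cos), so either ranks pairs identically; cosine similarity simply matches the commit's stated intent.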
download_images.py
@@ -0,0 +1,157 @@
import argparse
import re
import urllib.error
import urllib.parse
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import pandas as pd


def sanitize_name(value: str, fallback: str) -> str:
    cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", str(value)).strip()
    cleaned = re.sub(r"\s+", " ", cleaned)
    return cleaned[:120] if cleaned else fallback


def split_urls(raw_value: object) -> list[str]:
    if pd.isna(raw_value):
        return []
    text = str(raw_value).strip()
    if not text:
        return []
    return [u.strip() for u in text.split(";") if u.strip()]


def file_extension_from_url(url: str) -> str:
    parsed = urllib.parse.urlparse(url)
    path = parsed.path or ""
    suffix = Path(path).suffix.lower()
    if suffix in {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}:
        return suffix
    return ".jpg"


def download_file(url: str, destination: Path, timeout: int, retries: int) -> tuple[bool, str]:
    destination.parent.mkdir(parents=True, exist_ok=True)
    for attempt in range(retries + 1):
        try:
            with urllib.request.urlopen(url, timeout=timeout) as response:
                if response.status != 200:
                    raise urllib.error.HTTPError(
                        url=url,
                        code=response.status,
                        msg=f"HTTP status {response.status}",
                        hdrs=response.headers,
                        fp=None,
                    )
                data = response.read()
                destination.write_bytes(data)
                return True, f"saved -> {destination}"
        except Exception as exc:
            if attempt == retries:
                return False, f"failed -> {url} ({exc})"
    return False, f"failed -> {url}"


def resolve_id(row: pd.Series, row_index: int, id_column: str | None) -> str:
    if id_column and id_column in row and pd.notna(row[id_column]):
        return sanitize_name(str(row[id_column]), f"row_{row_index}")
    return f"row_{row_index}"


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Download images from semicolon-separated URL column in an Excel sheet."
    )
    parser.add_argument(
        "--input",
        default="listing_data.xlsx",
        help="Path to the input Excel file. Default: listing_data.xlsx",
    )
    parser.add_argument(
        "--output-dir",
        default="downloaded_images",
        help="Base output directory for downloaded images.",
    )
    parser.add_argument(
        "--image-column",
        default="Image",
        help="Column containing semicolon-separated image URLs.",
    )
    parser.add_argument(
        "--id-column",
        default="ASIN",
        help="Column used to name per-row folders. Use empty string to disable.",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=20,
        help="HTTP timeout in seconds per request.",
    )
    parser.add_argument(
        "--retries",
        type=int,
        default=2,
        help="Retry count for failed downloads.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=10,
        help="Parallel download worker count.",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    output_dir = Path(args.output_dir)
    id_column = args.id_column.strip() or None

    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    df = pd.read_excel(input_path)
    if args.image_column not in df.columns:
        raise ValueError(
            f"Image column '{args.image_column}' not found. Available columns: {list(df.columns)}"
        )

    tasks = []
    for idx, row in df.iterrows():
        row_id = resolve_id(row, idx, id_column)
        row_urls = split_urls(row[args.image_column])
        for image_index, url in enumerate(row_urls, start=1):
            ext = file_extension_from_url(url)
            filename = f"{image_index:02d}{ext}"
            destination = output_dir / row_id / filename
            tasks.append((url, destination))

    if not tasks:
        print("No image URLs found. Nothing to download.")
        return

    print(f"Queued {len(tasks)} image(s) from {len(df)} row(s).")
    success_count = 0
    fail_count = 0

    with ThreadPoolExecutor(max_workers=max(args.workers, 1)) as executor:
        futures = [
            executor.submit(download_file, url, destination, args.timeout, args.retries)
            for url, destination in tasks
        ]
        for future in as_completed(futures):
            ok, message = future.result()
            if ok:
                success_count += 1
            else:
                fail_count += 1
            print(message)

    print(f"\nDone. Success: {success_count}, Failed: {fail_count}")
    print(f"Images saved under: {output_dir.resolve()}")


if __name__ == "__main__":
    main()
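The helpers in this script are pure functions of their inputs, so they can be sanity-checked without touching the network. A minimal sketch, assuming the file is importable as download_images (the sample URLs are hypothetical):

from download_images import split_urls, file_extension_from_url

# Cells are split on ";", trimmed, and blanks dropped.
assert split_urls("https://example.com/a.jpg ; https://example.com/b.png;") == [
    "https://example.com/a.jpg",
    "https://example.com/b.png",
]
assert split_urls(float("nan")) == []  # empty Excel cells arrive as NaN

# Known image extensions are kept; anything else falls back to ".jpg".
assert file_extension_from_url("https://example.com/img.webp") == ".webp"
assert file_extension_from_url("https://example.com/image?id=42") == ".jpg"

One operational caveat: urllib.request.urlopen sends a Python-urllib User-Agent by default, which some image CDNs reject; wrapping the URL in a urllib.request.Request with a browser-like User-Agent header is a common workaround.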
generate_pairs_csv.py
@@ -0,0 +1,126 @@
import argparse
import csv
import hashlib
import random
from itertools import combinations
from pathlib import Path


IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def sha1_of_file(path: Path) -> str:
    hasher = hashlib.sha1()
    with path.open("rb") as f:
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            hasher.update(chunk)
    return hasher.hexdigest()


def collect_images(root: Path) -> list[Path]:
    return [
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    ]


def build_positive_pairs(groups: dict[str, list[Path]]) -> list[tuple[Path, Path, int]]:
    pairs: list[tuple[Path, Path, int]] = []
    for paths in groups.values():
        if len(paths) < 2:
            continue
        for a, b in combinations(paths, 2):
            pairs.append((a, b, 1))
    return pairs


def build_negative_pairs(
    groups: dict[str, list[Path]],
    target_count: int,
    rng: random.Random,
) -> list[tuple[Path, Path, int]]:
    keys = list(groups.keys())
    if len(keys) < 2:
        return []

    pairs: list[tuple[Path, Path, int]] = []
    seen = set()
    max_attempts = target_count * 20 if target_count > 0 else 0
    attempts = 0

    while len(pairs) < target_count and attempts < max_attempts:
        attempts += 1
        k1, k2 = rng.sample(keys, 2)
        p1 = rng.choice(groups[k1])
        p2 = rng.choice(groups[k2])

        key = tuple(sorted((str(p1), str(p2))))
        if key in seen:
            continue
        seen.add(key)
        pairs.append((p1, p2, 0))
    return pairs


def write_csv(rows: list[tuple[Path, Path, int]], output: Path, base: Path) -> None:
    output.parent.mkdir(parents=True, exist_ok=True)
    with output.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_path_1", "image_path_2", "label"])
        for p1, p2, label in rows:
            writer.writerow([
                str(p1.relative_to(base)).replace("\\", "/"),
                str(p2.relative_to(base)).replace("\\", "/"),
                label,
            ])


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Create pair dataset CSV: label 1 for same image, else 0."
    )
    parser.add_argument("--root", default="data_images", help="Root folder containing images")
    parser.add_argument("--output", default="pairs_dataset.csv", help="Output CSV path")
    parser.add_argument(
        "--neg_ratio",
        type=float,
        default=1.0,
        help="Number of negative pairs per positive pair (default: 1.0)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    root = Path(args.root).resolve()
    output = Path(args.output).resolve()
    rng = random.Random(args.seed)

    images = collect_images(root)
    if not images:
        raise SystemExit(f"No images found under: {root}")

    hash_groups: dict[str, list[Path]] = {}
    for img in images:
        file_hash = sha1_of_file(img)
        hash_groups.setdefault(file_hash, []).append(img)

    pos_pairs = build_positive_pairs(hash_groups)
    neg_target = int(len(pos_pairs) * args.neg_ratio)
    neg_pairs = build_negative_pairs(hash_groups, neg_target, rng)

    all_pairs = pos_pairs + neg_pairs
    rng.shuffle(all_pairs)
    write_csv(all_pairs, output, root)

    print(f"Images found: {len(images)}")
    print(f"Unique images by hash: {len(hash_groups)}")
    print(f"Positive pairs (label=1): {len(pos_pairs)}")
    print(f"Negative pairs (label=0): {len(neg_pairs)}")
    print(f"Total pairs written: {len(all_pairs)}")
    print(f"CSV: {output}")


if __name__ == "__main__":
    main()
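Because positives are built with combinations() inside each hash group, a group of n byte-identical files contributes n*(n-1)/2 positive pairs, so duplicate-heavy folders grow the CSV quadratically. A minimal sketch of that arithmetic, plus a look at the resulting label balance (the group sizes are hypothetical; pairs_dataset.csv is the default output name above):

from math import comb

import pandas as pd

# Three hypothetical hash groups holding 2, 3, and 5 identical copies.
group_sizes = [2, 3, 5]
expected_positives = sum(comb(n, 2) for n in group_sizes)  # 1 + 3 + 10 = 14

# With the default --neg_ratio of 1.0, negatives roughly match positives
# one-to-one (the rejection-sampling attempt cap can leave them slightly short).
df = pd.read_csv("pairs_dataset.csv")
print(df["label"].value_counts())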