72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
import warnings
|
|
from dotenv import load_dotenv
|
|
import torch,glob
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
import os
|
|
from qdrant_client import AsyncQdrantClient, models
|
|
from db_setup import get_qdrant_client
|
|
from vector_db_router.models import CollectionHandler
|
|
from vector_db_router.serializers import CreateCollectionSerializer
|
|
import torch.nn.functional as F
|
|
# Read settings from a .env file before anything else in this module runs.
load_dotenv()

# DINOv2 internals warn when the optional xFormers package is missing.
# That warning is non-critical, so it is filtered out here.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message="xFormers is not available.*",
)
|
|
|
|
# Device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the DINOv2 ViT-B/14 entry point through torch.hub at import time,
# switch it to evaluation mode and move it onto the selected device.
# Both eval() and to() return the module itself, so the calls can be chained.
model = torch.hub.load(
    'facebookresearch/dinov2',
    'dinov2_vitb14',
).eval().to(device)
|
|
|
|
# Image preprocessing applied to every image before it reaches the model:
# resize to a fixed 518x518 (DINOv2 recommended size), then convert the PIL
# image to a tensor.
# NOTE(review): there is no transforms.Normalize(mean, std) step here. The
# DINOv2 reference preprocessing appears to use ImageNet mean/std - confirm
# whether the omission is intentional. Adding it would change every embedding,
# so vectors that are already indexed would have to be regenerated.
transform = transforms.Compose(
    [
        transforms.Resize(size=(518, 518)),
        transforms.ToTensor(),
    ]
)
|
|
|
|
def get_embedding(image_path):
    """Compute an L2-normalised model embedding for a single image.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a path to
            an image file).

    Returns:
        A CPU ``torch.Tensor`` holding the model output for a batch of one
        image, L2-normalised along dimension 1.

    Raises:
        Any exception raised by ``PIL.Image.open`` / ``convert`` for a missing
        or unreadable image, or by the model forward pass. Nothing is caught
        here; callers decide how to handle failures.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically. convert() returns a fully loaded, independent image,
    # so it stays usable after the source file has been closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Preprocess, add the leading batch dimension and move to the model device.
    tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        embedding = model(tensor)

    # Normalize embedding (important for cosine similarity).
    embedding = F.normalize(embedding, p=2, dim=1)

    return embedding.cpu()
|
|
|
|
def get_vectors(image_input, item):
    """Return the embedding of an item's image as a plain Python list.

    ``image_input`` is used directly when it names an existing file.
    Otherwise a ``*.jpg`` file is looked up in ``<search_base>/<item>/``, where
    ``search_base`` is ``image_input`` if that is an existing directory and the
    ``downloaded_images`` folder next to this module if not.

    Args:
        image_input: Path to an image file, a directory to search in, or a
            falsy value to search the default folder.
        item: Name of the sub-directory to search (the item's ASIN, per the
            original comment); also used in the error message.

    Returns:
        The result of ``get_embedding(path).squeeze().tolist()`` (presumably a
        flat list of floats - confirm against the model output shape), or
        ``None`` when no image is found or any error occurs. Errors are
        printed, never raised, so a batch caller can skip the item.
    """
    try:
        base_dir = os.path.join(os.path.dirname(__file__), "downloaded_images")
        path = image_input

        # If image_input is not a valid file, try to find one using the item.
        if not (path and os.path.isfile(path)):
            # If path is a directory, use it as the search base.
            search_base = path if (path and os.path.isdir(path)) else base_dir

            # Escape the directory part so characters such as "[", "*" or "?"
            # in the base path or item name are matched literally; only the
            # trailing "*.jpg" is meant to be a wildcard.
            item_dir = glob.escape(os.path.join(search_base, item))
            jpg_files = glob.glob(os.path.join(item_dir, "*.jpg"))
            if not jpg_files:
                return None

            # glob returns matches in arbitrary order; take the
            # lexicographically smallest so the same item always maps to the
            # same image (and therefore the same vector).
            path = min(jpg_files)

        # Generate the vector for the identified image file.
        emb = get_embedding(path)
        return emb.squeeze().tolist()
    except Exception as e:
        # Deliberate best-effort: report the failure and signal "no vector".
        print(f"Error generating vector for {item}: {e}")
        return None