diff --git a/.gitignore b/.gitignore index 1af7e7b..efe7b25 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ **data** *pyc** **pycache** -*agent/** \ No newline at end of file +*agent/** +**downloaded_images** \ No newline at end of file diff --git a/dev_backend/.env b/dev_backend/.env deleted file mode 100644 index 7d6e8e1..0000000 --- a/dev_backend/.env +++ /dev/null @@ -1,9 +0,0 @@ -MYSQL_HOST=localhost - -MYSQL_PORT=3306 - -MYSQL_USER=root - -MYSQL_PASSWORD='AmB@ig123' - -MYSQL_DATABASE=listing_radar diff --git a/dev_backend/__pycache__/main.cpython-313.pyc b/dev_backend/__pycache__/main.cpython-313.pyc index 9a0c7ae..a33d8a7 100644 Binary files a/dev_backend/__pycache__/main.cpython-313.pyc and b/dev_backend/__pycache__/main.cpython-313.pyc differ diff --git a/dev_backend/main.py b/dev_backend/main.py index 5047c25..b6b5484 100644 --- a/dev_backend/main.py +++ b/dev_backend/main.py @@ -1,3 +1,9 @@ +import sys +import os + +# Add the project root to sys.path to allow imports from model_export +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from fastapi import FastAPI, status from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from dotenv import load_dotenv diff --git a/dev_backend/vector_db_router/models.py b/dev_backend/vector_db_router/models.py index e095e92..16cac52 100644 --- a/dev_backend/vector_db_router/models.py +++ b/dev_backend/vector_db_router/models.py @@ -3,13 +3,24 @@ from qdrant_client.models import PointStruct from typing import Dict, Any class CollectionHandler: - def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, id: int): + def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, + id: int=None, + link: str=None, + asin: str=None, + category: str=None, + brand: str=None, + client: AsyncQdrantClient=None + ): self.collection_name = collection_name self.vector = vector + self.id = id self.vector_size = vector_size self.payload = payload - self.id = id - self.client = AsyncQdrantClient("localhost", port=6333) + self.link = link + self.asin = asin + self.category = category + self.brand = brand + self.client = client if client else AsyncQdrantClient("localhost", port=6333) async def create_collection(self): try: @@ -18,21 +29,34 @@ class CollectionHandler: await self.client.create_collection( collection_name=self.collection_name, - vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.COSINE), + vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID), optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000) ) # Creating payload indexes as per project logic + await self.client.create_payload_index( collection_name=self.collection_name, - field_name="Product_ID", + field_name="link", field_schema=models.PayloadSchemaType.KEYWORD ) await self.client.create_payload_index( collection_name=self.collection_name, - field_name="Product_Link", + field_name="title", field_schema=models.PayloadSchemaType.KEYWORD ) + + await self.client.create_payload_index( + collection_name=self.collection_name, + field_name="brand", + field_schema=models.PayloadSchemaType.KEYWORD + ) + await self.client.create_payload_index( + collection_name=self.collection_name, + field_name="asin", + field_schema=models.PayloadSchemaType.KEYWORD + ) + return {"message": f"Collection {self.collection_name} created successfully"} except Exception as e: @@ -46,10 +70,10 @@ class CollectionHandler: PointStruct(id=self.id, vector=self.vector, payload=self.payload) ] ) - print("Data inserted successfully") return True except Exception as e: - print("Insertion failed: ", e) + # Note: In a real app we'd use a logger here + print(f"Insertion failed for ID {self.id}: {e}") return False async def upsert_point(self): diff --git a/dev_backend/vector_db_router/views.py b/dev_backend/vector_db_router/views.py index f768947..b87275e 100644 --- a/dev_backend/vector_db_router/views.py +++ b/dev_backend/vector_db_router/views.py @@ -9,9 +9,92 @@ from .serializers import ( UpdateCollectionSerializer, DeleteCollectionSerializer ) +from model_export.dino_image_matching import get_vectors from .models import CollectionHandler - +import os app_router = APIRouter() +import logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +import pandas as pd + +@app_router.get("/get_vectors") +async def get_vectors_endpoint( + q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], + image_path:str=os.getenv("DATASET") +): + try: + # Construct path relative to this file + base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx") + df = pd.read_excel(excel_path) + asin_list = df['ASIN'].dropna().astype(str).tolist() + log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}") + + # 1. Initialize/Create collection "Product" + # DINOv2 vitb14 vector size is 768 + init_handler = CollectionHandler( + collection_name="Product", + vector=[], + vector_size=768, + payload={}, + client=q + ) + await init_handler.create_collection() + + result_lst = [] + for index, row in df.iterrows(): + asin = str(row['ASIN']) + if pd.isna(row['ASIN']): + continue + + title = str(row['Title']) + brand = str(row['Brand']) + link = str(row['Image']) + + # Call get_vectors + vector = get_vectors(image_input=image_path, item=asin) + + if vector is None: + log.warning(f"Skipping {asin} due to missing image/vector") + continue + + payload = { + "asin": asin, + "title": title, + "brand": brand, + "link": link + } + + # 2. Ingest into Qdrant using CollectionHandler + # Use the injected client 'q' and convert index to int + handler = CollectionHandler( + collection_name="Product", + vector=vector, + vector_size=768, + payload=payload, + id=int(index), + link=link, + asin=asin, + brand=brand, + client=q + ) + success = await handler.upsert_point() + + if success: + result_lst.append({ + "item": asin, + "status": "ingested", + "payload": payload + }) + log.info(f"Vector ingested for {asin} (ID: {index})") + else: + log.error(f"Failed to ingest vector for {asin}") + + return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @app_router.post("/create") async def create_collection_endpoint( diff --git a/model_export/dino_image_matching.py b/model_export/dino_image_matching.py index 0005903..54c369f 100644 --- a/model_export/dino_image_matching.py +++ b/model_export/dino_image_matching.py @@ -1,9 +1,15 @@ import warnings - -import torch +from dotenv import load_dotenv +import torch,glob from PIL import Image from torchvision import transforms +import os +from qdrant_client import AsyncQdrantClient, models +from db_setup import get_qdrant_client +from vector_db_router.models import CollectionHandler +from vector_db_router.serializers import CreateCollectionSerializer import torch.nn.functional as F +load_dotenv() # Optional dependency warnings from DINOv2 internals are non-critical. warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning) @@ -42,13 +48,25 @@ def get_embedding(image_path): return embedding.cpu() -# Example -emb1 = get_embedding(r"data_images\B0B39FFJHF\03.jpg") -emb2 = get_embedding(r"data_images\B09RWY127Q\03.jpg") - -# Cosine similarity -similarity = torch.nn.functional.pdist( - torch.cat([emb1, emb2]) -) - -print("Distance:", similarity.item()) \ No newline at end of file +def get_vectors(image_input, item): + try: + base_dir = os.path.join(os.path.dirname(__file__), "downloaded_images") + path = image_input + + # If image_input is not a valid file, try to find one using the item (ASIN) + if not (path and os.path.isfile(path)): + # If path is a directory, use it as base_dir + search_base = path if (path and os.path.isdir(path)) else base_dir + glob_pattern = os.path.join(search_base, item, "*.jpg") + jpg_files = glob.glob(glob_pattern) + if jpg_files: + path = jpg_files[0] + else: + return None + + # Generate the vector for the identified image file + emb = get_embedding(path) + return emb.squeeze().tolist() + except Exception as e: + print(f"Error generating vector for {item}: {e}") + return None \ No newline at end of file diff --git a/req.txt b/req.txt new file mode 100644 index 0000000..bd27ef8 --- /dev/null +++ b/req.txt @@ -0,0 +1,76 @@ +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.13.0 +asyncmy==0.2.11 +certifi==2026.4.22 +click==8.3.3 +contourpy==1.3.3 +cuda-bindings==13.2.0 +cuda-pathfinder==1.5.4 +cuda-toolkit==13.0.2 +cycler==0.12.1 +et_xmlfile==2.0.0 +fastapi==0.136.1 +filelock==3.29.0 +fonttools==4.62.1 +fsspec==2026.4.0 +greenlet==3.5.0 +grpcio==1.80.0 +h11==0.16.0 +h2==4.3.0 +hpack==4.1.0 +httpcore==1.0.9 +httpx==0.28.1 +hyperframe==6.1.0 +idna==3.13 +Jinja2==3.1.6 +joblib==1.5.3 +kiwisolver==1.5.0 +MarkupSafe==3.0.3 +matplotlib==3.10.9 +mpmath==1.3.0 +networkx==3.6.1 +numpy==2.4.4 +nvidia-cublas==13.1.0.3 +nvidia-cuda-cupti==13.0.85 +nvidia-cuda-nvrtc==13.0.88 +nvidia-cuda-runtime==13.0.96 +nvidia-cudnn-cu13==9.19.0.56 +nvidia-cufft==12.0.0.61 +nvidia-cufile==1.15.1.6 +nvidia-curand==10.4.0.35 +nvidia-cusolver==12.0.4.66 +nvidia-cusparse==12.6.3.3 +nvidia-cusparselt-cu13==0.8.0 +nvidia-nccl-cu13==2.28.9 +nvidia-nvjitlink==13.0.88 +nvidia-nvshmem-cu13==3.4.5 +nvidia-nvtx==13.0.85 +openpyxl==3.1.5 +packaging==26.2 +pandas==3.0.2 +pillow==12.2.0 +portalocker==3.2.0 +protobuf==7.34.1 +pydantic==2.13.4 +pydantic_core==2.46.4 +pyparsing==3.3.2 +python-dateutil==2.9.0.post0 +python-dotenv==1.2.2 +qdrant-client==1.17.1 +scikit-learn==1.8.0 +scipy==1.17.1 +setuptools==81.0.0 +six==1.17.0 +SQLAlchemy==2.0.49 +starlette==1.0.0 +sympy==1.14.0 +threadpoolctl==3.6.0 +torch==2.11.0 +torchvision==0.26.0 +tqdm==4.67.3 +triton==3.6.0 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +urllib3==2.6.3 +uvicorn==0.46.0