diff --git a/.gitignore b/.gitignore index 5cd4295..13ffc51 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,5 @@ **pycache** *agent/** **downloaded_images** -__pycache__/ -*.pyc \ No newline at end of file +**model_export** +**cpython** \ No newline at end of file diff --git a/dev_backend/vector_db_router/models.py b/dev_backend/vector_db_router/models.py index c011d98..a2abec8 100644 --- a/dev_backend/vector_db_router/models.py +++ b/dev_backend/vector_db_router/models.py @@ -1,104 +1,108 @@ -from qdrant_client import AsyncQdrantClient, models -from qdrant_client.models import PointStruct -from typing import Dict, Any - -class CollectionHandler: - def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, - id: int=None, - link: str=None, - asin: str=None, - category: str=None, - brand: str=None, - client: AsyncQdrantClient=None - ): - self.collection_name = collection_name - self.vector = vector - self.id = id - self.vector_size = vector_size - self.payload = payload - self.link = link - self.asin = asin - self.category = category - self.brand = brand - self.client = client if client else AsyncQdrantClient("localhost", port=6333) - - async def create_collection(self): - try: - if await self.client.collection_exists(self.collection_name): - return {"message": "Collection already exists"} - - await self.client.create_collection( - collection_name=self.collection_name, - vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID), - optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000) - ) - - # Creating payload indexes as per project logic - - await self.client.create_payload_index( - collection_name=self.collection_name, - field_name="link", - field_schema=models.PayloadSchemaType.KEYWORD - ) - await self.client.create_payload_index( - collection_name=self.collection_name, - field_name="title", - field_schema=models.PayloadSchemaType.KEYWORD - ) - - await self.client.create_payload_index( - collection_name=self.collection_name, - field_name="brand", - field_schema=models.PayloadSchemaType.KEYWORD - ) - await self.client.create_payload_index( - collection_name=self.collection_name, - field_name="asin", - field_schema=models.PayloadSchemaType.KEYWORD - ) - - - return {"message": f"Collection {self.collection_name} created successfully"} - except Exception as e: - return {"message": str(e)} - - async def insertion(self): - try: - await self.client.upsert( - collection_name=self.collection_name, - points=[ - PointStruct(id=self.id, vector=self.vector, payload=self.payload) - ] - ) - return True - except Exception as e: - # Note: In a real app we'd use a logger here - print(f"Insertion failed for ID {self.id}: {e}") - return False - - async def upsert_point(self): - return await self.insertion() - - async def search(self, query_vector): - try: - result = await self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - limit=10 - ) - return result - except Exception as e: - print("Search failed: ", e) - return None - - async def update_collection(self): - """Update is implemented as an upsert of the point data.""" - return await self.upsert_point() - - async def delete_collection(self): - try: - await self.client.delete_collection(collection_name=self.collection_name) - return {"message": f"Collection {self.collection_name} deleted successfully"} - except Exception as e: - return {"message": str(e)} - +from qdrant_client import AsyncQdrantClient, models +from qdrant_client.models import PointStruct +from typing import Dict, Any +import logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +class CollectionHandler: + def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, + id: int=None, + link: str=None, + asin: str=None, + category: str=None, + brand: str=None, + client: AsyncQdrantClient=None + ): + self.collection_name = collection_name + self.vector = vector + self.id = id + self.vector_size = vector_size + self.payload = payload + self.link = link + self.asin = asin + self.category = category + self.brand = brand + self.client = client if client else AsyncQdrantClient("localhost", port=6333) + + async def create_collection(self): + try: + if await self.client.collection_exists(self.collection_name): + return {"message": "Collection already exists"} + + await self.client.create_collection( + collection_name=self.collection_name, + vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID), + optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000) + ) + + # Creating payload indexes as per project logic + + await self.client.create_payload_index( + collection_name=self.collection_name, + field_name="link", + field_schema=models.PayloadSchemaType.KEYWORD + ) + await self.client.create_payload_index( + collection_name=self.collection_name, + field_name="title", + field_schema=models.PayloadSchemaType.KEYWORD + ) + + await self.client.create_payload_index( + collection_name=self.collection_name, + field_name="brand", + field_schema=models.PayloadSchemaType.KEYWORD + ) + await self.client.create_payload_index( + collection_name=self.collection_name, + field_name="asin", + field_schema=models.PayloadSchemaType.KEYWORD + ) + + + return {"message": f"Collection {self.collection_name} created successfully"} + except Exception as e: + return {"message": str(e)} + + async def insertion(self): + try: + await self.client.upsert( + collection_name=self.collection_name, + points=[ + PointStruct(id=self.id, vector=self.vector, payload=self.payload) + ] + ) + return True + except Exception as e: + # Note: In a real app we'd use a logger here + print(f"Insertion failed for ID {self.id}: {e}") + return False + + async def upsert_point(self): + return await self.insertion() + + async def search(self, query_vector, score_threshold: float = 0.3, limit: int = 10): + try: + result = await self.client.query_points( + collection_name=self.collection_name, + query=query_vector, + score_threshold=score_threshold, + limit=limit + ) + return result.points + except Exception as e: + log.error(f"Search failed: {e}") + return None + + async def update_collection(self): + """Update is implemented as an upsert of the point data.""" + return await self.upsert_point() + + async def delete_collection(self): + try: + await self.client.delete_collection(collection_name=self.collection_name) + return {"message": f"Collection {self.collection_name} deleted successfully"} + except Exception as e: + return {"message": str(e)} + diff --git a/dev_backend/vector_db_router/plugins.py b/dev_backend/vector_db_router/plugins.py new file mode 100644 index 0000000..a500340 --- /dev/null +++ b/dev_backend/vector_db_router/plugins.py @@ -0,0 +1,90 @@ +import os +import requests +from urllib.parse import urlparse +from pathlib import Path +from PIL import Image + +def download_image(url: str, filename: str = None) -> str: + """ + Download an image from URL and save it in data/temp/ folder. + + Args: + url (str): Image URL + filename (str, optional): Custom filename. If None, extracted from URL. + + Returns: + str: Full path to the downloaded image + """ + try: + # Get project root directory (where your main script is) + root_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent + + # Create data/temp folder structure + temp_dir = root_dir / "data" / "temp" + temp_dir.mkdir(parents=True, exist_ok=True) + + # Generate filename if not provided + if not filename: + parsed_url = urlparse(url) + filename = os.path.basename(parsed_url.path) + if not filename or "." not in filename: + # Fallback filename + ext = filename.split('.')[-1] if '.' in filename else 'jpg' + filename = f"image_{hash(url) % 100000}.{ext}" + + # Ensure filename has extension + if '.' not in filename: + filename += ".jpg" + + file_path = temp_dir / filename + + # Download the image + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + + # Save image + with open(file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print(f"✅ Image downloaded: {file_path}") + return str(file_path) + + except Exception as e: + print(f"❌ Failed to download image: {e}") + raise + +def read_image(image_path: str) -> Image.Image: + """ + Read an image from the given path and return a PIL Image object. + + Args: + image_path (str): Path to the image file + + Returns: + PIL.Image.Image: Loaded image + + Raises: + FileNotFoundError: If image doesn't exist + Exception: For other image loading errors + """ + try: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image not found at path: {image_path}") + + # Open the image + image = Image.open(image_path) + + # Convert to RGB (important for DINOv2 and most models) + if image.mode != "RGB": + image = image.convert("RGB") + + print(f"✅ Image loaded successfully: {image_path} | Size: {image.size}") + return image + + except FileNotFoundError as e: + print(f"❌ File not found: {e}") + raise + except Exception as e: + print(f"❌ Failed to read image: {e}") + raise \ No newline at end of file diff --git a/dev_backend/vector_db_router/serializers.py b/dev_backend/vector_db_router/serializers.py index fa3bbb1..7682200 100644 --- a/dev_backend/vector_db_router/serializers.py +++ b/dev_backend/vector_db_router/serializers.py @@ -1,22 +1,24 @@ -from pydantic import BaseModel -from typing import Dict, List, Any - -class CreateCollectionSerializer(BaseModel): - collection_name: str - vector: List[float] - vector_size: int - payload: Dict[str, Any] - id: int - -class QueryCollectionSerializer(BaseModel): - collection_name: str - query_vector: List[float] - -class UpdateCollectionSerializer(BaseModel): - collection_name: str - vector: List[float] - payload: Dict[str, Any] - id: int - -class DeleteCollectionSerializer(BaseModel): - collection_name: str +from pydantic import BaseModel +from typing import Dict, List, Any + +class CreateCollectionSerializer(BaseModel): + collection_name: str + vector: List[float] + vector_size: int + payload: Dict[str, Any] + id: int + +class QueryCollectionSerializer(BaseModel): + collection_name: str + url: str + score_threshold: float = 0.3 # Euclidean distance — lower = more similar. 0.3 = very tight match + limit: int = 10 + +class UpdateCollectionSerializer(BaseModel): + collection_name: str + vector: List[float] + payload: Dict[str, Any] + id: int + +class DeleteCollectionSerializer(BaseModel): + collection_name: str diff --git a/dev_backend/vector_db_router/views.py b/dev_backend/vector_db_router/views.py index c88a5c5..12ebfb7 100644 --- a/dev_backend/vector_db_router/views.py +++ b/dev_backend/vector_db_router/views.py @@ -1,183 +1,233 @@ -from db_setup import get_qdrant_client -from typing import Annotated -from fastapi import Depends, HTTPException, APIRouter -from qdrant_client import AsyncQdrantClient -from fastapi.responses import JSONResponse -from .serializers import ( - CreateCollectionSerializer, - QueryCollectionSerializer, - UpdateCollectionSerializer, - DeleteCollectionSerializer -) -from model_export.dino_image_matching import get_vectors -from .models import CollectionHandler -import os -app_router = APIRouter() -import logging -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') -log = logging.getLogger(__name__) - -import pandas as pd - -@app_router.get("/get_vectors") -async def get_vectors_endpoint( - q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], - image_path:str=os.getenv("DATASET") -): - try: - # Construct path relative to this file - base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx") - df = pd.read_excel(excel_path) - asin_list = df['ASIN'].dropna().astype(str).tolist() - log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}") - - # 1. Initialize/Create collection "Product" - # DINOv2 vitb14 vector size is 768 - init_handler = CollectionHandler( - collection_name="Product", - vector=[], - vector_size=768, - payload={}, - client=q - ) - await init_handler.create_collection() - - result_lst = [] - for index, row in df.iterrows(): - asin = str(row['ASIN']) - if pd.isna(row['ASIN']): - continue - - title = str(row['Title']) - brand = str(row['Brand']) - link = str(row['Image']) - - # Call get_vectors - vector = get_vectors(image_input=image_path, item=asin) - - if vector is None: - log.warning(f"Skipping {asin} due to missing image/vector") - continue - - payload = { - "asin": asin, - "title": title, - "brand": brand, - "link": link - } - - # 2. Ingest into Qdrant using CollectionHandler - # Use the injected client 'q' and convert index to int - handler = CollectionHandler( - collection_name="Product", - vector=vector, - vector_size=768, - payload=payload, - id=int(index), - link=link, - asin=asin, - brand=brand, - client=q - ) - success = await handler.upsert_point() - - if success: - result_lst.append({ - "item": asin, - "status": "ingested", - "payload": payload - }) - log.info(f"Vector ingested for {asin} (ID: {index})") - else: - log.error(f"Failed to ingest vector for {asin}") - - return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst}) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app_router.post("/create") -async def create_collection_endpoint( - q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], - body: CreateCollectionSerializer = None -): - try: - if body is None: - raise HTTPException(status_code=400, detail="Collection name is required") - - print("collection_name: ", body.collection_name) - - handler = CollectionHandler( - collection_name=body.collection_name, - vector=body.vector, - vector_size=body.vector_size, - payload=body.payload, - id=body.id - ) - - # 1. Create collection - result = await handler.create_collection() - - # 2. Automatically call upsert_point - await handler.upsert_point() - - return JSONResponse(result) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app_router.get("/query") -async def query_collection_endpoint( - q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], - body: QueryCollectionSerializer -): - try: - handler = CollectionHandler( - collection_name=body.collection_name, - vector=body.query_vector, - vector_size=len(body.query_vector), - payload={}, - id=0 - ) - result = await handler.search(body.query_vector) - return JSONResponse({"results": str(result)}) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app_router.put("/update") -async def update_collection_endpoint( - q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], - body: UpdateCollectionSerializer -): - try: - handler = CollectionHandler( - collection_name=body.collection_name, - vector=body.vector, - vector_size=len(body.vector), - payload=body.payload, - id=body.id - ) - result = await handler.update_collection() - return JSONResponse({"status": "success", "result": result}) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app_router.delete("/delete") -async def delete_collection_endpoint( - q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], - body: DeleteCollectionSerializer -): - try: - handler = CollectionHandler( - collection_name=body.collection_name, - vector=[], - vector_size=0, - payload={}, - id=0 - ) - result = await handler.delete_collection() - return JSONResponse(result) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - +from db_setup import get_qdrant_client +from typing import Annotated +from fastapi import Depends, HTTPException, APIRouter +from qdrant_client import AsyncQdrantClient +from .plugins import download_image,read_image +from fastapi.responses import JSONResponse +from .serializers import ( + CreateCollectionSerializer, + QueryCollectionSerializer, + UpdateCollectionSerializer, + DeleteCollectionSerializer +) +from model_export.dino_image_matching import get_vectors,get_embedding +from .models import CollectionHandler +import os +app_router = APIRouter() +import logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +import pandas as pd + +@app_router.get("/get_vectors") +async def get_vectors_endpoint( + q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], + image_path:str=os.getenv("DATASET") +): + try: + # Construct path relative to this file + base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx") + df = pd.read_excel(excel_path) + asin_list = df['ASIN'].dropna().astype(str).tolist() + log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}") + + # 1. Initialize/Create collection "Product" + # DINOv2 vitb14 vector size is 768 + init_handler = CollectionHandler( + collection_name="Product", + vector=[], + vector_size=768, + payload={}, + client=q + ) + await init_handler.create_collection() + + result_lst = [] + for index, row in df.iterrows(): + asin = str(row['ASIN']) + if pd.isna(row['ASIN']): + continue + + title = str(row['Title']) + brand = str(row['Brand']) + link = str(row['Image']) + + # Call get_vectors + vector = get_vectors(image_input=image_path, item=asin) + + if vector is None: + log.warning(f"Skipping {asin} due to missing image/vector") + continue + + payload = { + "asin": asin, + "title": title, + "brand": brand, + "link": link + } + + # 2. Ingest into Qdrant using CollectionHandler + # Use the injected client 'q' and convert index to int + handler = CollectionHandler( + collection_name="Product", + vector=vector, + vector_size=768, + payload=payload, + id=int(index), + link=link, + asin=asin, + brand=brand, + client=q + ) + success = await handler.upsert_point() + + if success: + result_lst.append({ + "item": asin, + "status": "ingested", + "payload": payload + }) + log.info(f"Vector ingested for {asin} (ID: {index})") + else: + log.error(f"Failed to ingest vector for {asin}") + + return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app_router.post("/create") +async def create_collection_endpoint( + q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], + body: CreateCollectionSerializer = None +): + try: + if body is None: + raise HTTPException(status_code=400, detail="Collection name is required") + + print("collection_name: ", body.collection_name) + + handler = CollectionHandler( + collection_name=body.collection_name, + vector=body.vector, + vector_size=body.vector_size, + payload=body.payload, + id=body.id + ) + + # 1. Create collection + result = await handler.create_collection() + + # 2. Automatically call upsert_point + await handler.upsert_point() + + return JSONResponse(result) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app_router.get("/query") +async def query_collection_endpoint( + q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], + body: QueryCollectionSerializer +): + try: + result = [] + if isinstance(body.url, str): + # Handle semicolon-separated URLs by taking the first one + target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url + log.info(f"Querying collection {body.collection_name} with URL: {target_url}") + downloaded_image_path = download_image(target_url) + query_vector = get_embedding(downloaded_image_path) + # get_embedding already returns a flat list of 768 floats + + handler = CollectionHandler( + collection_name=body.collection_name, + vector=query_vector, + vector_size=len(query_vector), + payload={}, + id=0, + client=q + ) + search_result = await handler.search( + query_vector, + score_threshold=body.score_threshold, + limit=body.limit + ) + if search_result: + result = [ + {"id": p.id, "score": p.score, "payload": p.payload} + for p in search_result + ] + else: + result = [] # No match within threshold + + elif isinstance(body.url, list): + result = [] + for url in body.url: + downloaded_image_path = download_image(url) + query_vector = get_embedding(downloaded_image_path) + # get_embedding already returns a flat list of 768 floats + + handler = CollectionHandler( + collection_name=body.collection_name, + vector=query_vector, + vector_size=len(query_vector), + payload={}, + id=0, + client=q + ) + search_result = await handler.search( + query_vector, + score_threshold=body.score_threshold, + limit=body.limit + ) + if search_result: + result.append([ + {"id": p.id, "score": p.score, "payload": p.payload} + for p in search_result + ]) + + return JSONResponse({"results": result}) + except Exception as e: + log.error(f"Query failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@app_router.put("/update") +async def update_collection_endpoint( + q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], + body: UpdateCollectionSerializer +): + try: + handler = CollectionHandler( + collection_name=body.collection_name, + vector=body.vector, + vector_size=len(body.vector), + payload=body.payload, + id=body.id + ) + result = await handler.update_collection() + return JSONResponse({"status": "success", "result": result}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app_router.delete("/delete") +async def delete_collection_endpoint( + q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], + body: DeleteCollectionSerializer +): + try: + handler = CollectionHandler( + collection_name=body.collection_name, + vector=[], + vector_size=0, + payload={}, + id=0 + ) + result = await handler.delete_collection() + return JSONResponse(result) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + \ No newline at end of file diff --git a/model_export/dino_image_matching.py b/model_export/dino_image_matching.py index 54c369f..70edcc2 100644 --- a/model_export/dino_image_matching.py +++ b/model_export/dino_image_matching.py @@ -46,7 +46,8 @@ def get_embedding(image_path): # Normalize embedding (important for cosine similarity) embedding = F.normalize(embedding, p=2, dim=1) - return embedding.cpu() + # Return flat list (squeeze batch dim) + return embedding.squeeze(0).cpu().tolist() def get_vectors(image_input, item): try: