diff --git a/.gitignore b/.gitignore index efe7b25..13ffc51 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ *pyc** **pycache** *agent/** -**downloaded_images** \ No newline at end of file +**downloaded_images** +**model_export** +**cpython** \ No newline at end of file diff --git a/dev_backend/__pycache__/db_setup.cpython-313.pyc b/dev_backend/__pycache__/db_setup.cpython-313.pyc deleted file mode 100644 index cdc274a..0000000 Binary files a/dev_backend/__pycache__/db_setup.cpython-313.pyc and /dev/null differ diff --git a/dev_backend/__pycache__/main.cpython-313.pyc b/dev_backend/__pycache__/main.cpython-313.pyc deleted file mode 100644 index a33d8a7..0000000 Binary files a/dev_backend/__pycache__/main.cpython-313.pyc and /dev/null differ diff --git a/dev_backend/mysql_process/__pycache__/views.cpython-313.pyc b/dev_backend/mysql_process/__pycache__/views.cpython-313.pyc deleted file mode 100644 index 86d7527..0000000 Binary files a/dev_backend/mysql_process/__pycache__/views.cpython-313.pyc and /dev/null differ diff --git a/dev_backend/vector_db_router/models.py b/dev_backend/vector_db_router/models.py index 16cac52..a2abec8 100644 --- a/dev_backend/vector_db_router/models.py +++ b/dev_backend/vector_db_router/models.py @@ -1,6 +1,9 @@ from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct from typing import Dict, Any +import logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) class CollectionHandler: def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, @@ -79,16 +82,17 @@ class CollectionHandler: async def upsert_point(self): return await self.insertion() - async def search(self, query_vector): + async def search(self, query_vector, score_threshold: float = 0.3, limit: int = 10): try: - result = await self.client.search( + result = await self.client.query_points( collection_name=self.collection_name, - query_vector=query_vector, - limit=10 + query=query_vector, + score_threshold=score_threshold, + limit=limit ) - return result + return result.points except Exception as e: - print("Search failed: ", e) + log.error(f"Search failed: {e}") return None async def update_collection(self): diff --git a/dev_backend/vector_db_router/plugins.py b/dev_backend/vector_db_router/plugins.py new file mode 100644 index 0000000..a500340 --- /dev/null +++ b/dev_backend/vector_db_router/plugins.py @@ -0,0 +1,90 @@ +import os +import requests +from urllib.parse import urlparse +from pathlib import Path +from PIL import Image + +def download_image(url: str, filename: str = None) -> str: + """ + Download an image from URL and save it in data/temp/ folder. + + Args: + url (str): Image URL + filename (str, optional): Custom filename. If None, extracted from URL. + + Returns: + str: Full path to the downloaded image + """ + try: + # Get project root directory (where your main script is) + root_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent + + # Create data/temp folder structure + temp_dir = root_dir / "data" / "temp" + temp_dir.mkdir(parents=True, exist_ok=True) + + # Generate filename if not provided + if not filename: + parsed_url = urlparse(url) + filename = os.path.basename(parsed_url.path) + if not filename or "." not in filename: + # Fallback filename + ext = filename.split('.')[-1] if '.' in filename else 'jpg' + filename = f"image_{hash(url) % 100000}.{ext}" + + # Ensure filename has extension + if '.' not in filename: + filename += ".jpg" + + file_path = temp_dir / filename + + # Download the image + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + + # Save image + with open(file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print(f"✅ Image downloaded: {file_path}") + return str(file_path) + + except Exception as e: + print(f"❌ Failed to download image: {e}") + raise + +def read_image(image_path: str) -> Image.Image: + """ + Read an image from the given path and return a PIL Image object. + + Args: + image_path (str): Path to the image file + + Returns: + PIL.Image.Image: Loaded image + + Raises: + FileNotFoundError: If image doesn't exist + Exception: For other image loading errors + """ + try: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image not found at path: {image_path}") + + # Open the image + image = Image.open(image_path) + + # Convert to RGB (important for DINOv2 and most models) + if image.mode != "RGB": + image = image.convert("RGB") + + print(f"✅ Image loaded successfully: {image_path} | Size: {image.size}") + return image + + except FileNotFoundError as e: + print(f"❌ File not found: {e}") + raise + except Exception as e: + print(f"❌ Failed to read image: {e}") + raise \ No newline at end of file diff --git a/dev_backend/vector_db_router/serializers.py b/dev_backend/vector_db_router/serializers.py index dd43e20..7682200 100644 --- a/dev_backend/vector_db_router/serializers.py +++ b/dev_backend/vector_db_router/serializers.py @@ -10,7 +10,9 @@ class CreateCollectionSerializer(BaseModel): class QueryCollectionSerializer(BaseModel): collection_name: str - query_vector: List[float] + url: str + score_threshold: float = 0.3 # Euclidean distance — lower = more similar. 0.3 = very tight match + limit: int = 10 class UpdateCollectionSerializer(BaseModel): collection_name: str diff --git a/dev_backend/vector_db_router/views.py b/dev_backend/vector_db_router/views.py index b87275e..12ebfb7 100644 --- a/dev_backend/vector_db_router/views.py +++ b/dev_backend/vector_db_router/views.py @@ -2,6 +2,7 @@ from db_setup import get_qdrant_client from typing import Annotated from fastapi import Depends, HTTPException, APIRouter from qdrant_client import AsyncQdrantClient +from .plugins import download_image,read_image from fastapi.responses import JSONResponse from .serializers import ( CreateCollectionSerializer, @@ -9,7 +10,7 @@ from .serializers import ( UpdateCollectionSerializer, DeleteCollectionSerializer ) -from model_export.dino_image_matching import get_vectors +from model_export.dino_image_matching import get_vectors,get_embedding from .models import CollectionHandler import os app_router = APIRouter() @@ -132,16 +133,65 @@ async def query_collection_endpoint( body: QueryCollectionSerializer ): try: - handler = CollectionHandler( - collection_name=body.collection_name, - vector=body.query_vector, - vector_size=len(body.query_vector), - payload={}, - id=0 - ) - result = await handler.search(body.query_vector) - return JSONResponse({"results": str(result)}) + result = [] + if isinstance(body.url, str): + # Handle semicolon-separated URLs by taking the first one + target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url + log.info(f"Querying collection {body.collection_name} with URL: {target_url}") + downloaded_image_path = download_image(target_url) + query_vector = get_embedding(downloaded_image_path) + # get_embedding already returns a flat list of 768 floats + + handler = CollectionHandler( + collection_name=body.collection_name, + vector=query_vector, + vector_size=len(query_vector), + payload={}, + id=0, + client=q + ) + search_result = await handler.search( + query_vector, + score_threshold=body.score_threshold, + limit=body.limit + ) + if search_result: + result = [ + {"id": p.id, "score": p.score, "payload": p.payload} + for p in search_result + ] + else: + result = [] # No match within threshold + + elif isinstance(body.url, list): + result = [] + for url in body.url: + downloaded_image_path = download_image(url) + query_vector = get_embedding(downloaded_image_path) + # get_embedding already returns a flat list of 768 floats + + handler = CollectionHandler( + collection_name=body.collection_name, + vector=query_vector, + vector_size=len(query_vector), + payload={}, + id=0, + client=q + ) + search_result = await handler.search( + query_vector, + score_threshold=body.score_threshold, + limit=body.limit + ) + if search_result: + result.append([ + {"id": p.id, "score": p.score, "payload": p.payload} + for p in search_result + ]) + + return JSONResponse({"results": result}) except Exception as e: + log.error(f"Query failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app_router.put("/update") diff --git a/model_export/dino_image_matching.py b/model_export/dino_image_matching.py index 54c369f..70edcc2 100644 --- a/model_export/dino_image_matching.py +++ b/model_export/dino_image_matching.py @@ -46,7 +46,8 @@ def get_embedding(image_path): # Normalize embedding (important for cosine similarity) embedding = F.normalize(embedding, p=2, dim=1) - return embedding.cpu() + # Return flat list (squeeze batch dim) + return embedding.squeeze(0).cpu().tolist() def get_vectors(image_input, item): try: