Compare commits

..

No commits in common. "6ed79621b06617e82194c2b73eaeeba2b98bc3dd" and "56f52a4edcf6913ab0558a1e8afa572cbf30a57d" have entirely different histories.

6 changed files with 311 additions and 458 deletions

4
.gitignore vendored
View File

@ -5,5 +5,5 @@
**pycache** **pycache**
*agent/** *agent/**
**downloaded_images** **downloaded_images**
**model_export** __pycache__/
**cpython** *.pyc

View File

@ -1,9 +1,6 @@
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from typing import Dict, Any from typing import Dict, Any
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
class CollectionHandler: class CollectionHandler:
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
@ -82,17 +79,16 @@ class CollectionHandler:
async def upsert_point(self): async def upsert_point(self):
return await self.insertion() return await self.insertion()
async def search(self, query_vector, score_threshold: float = 0.3, limit: int = 10): async def search(self, query_vector):
try: try:
result = await self.client.query_points( result = await self.client.search(
collection_name=self.collection_name, collection_name=self.collection_name,
query=query_vector, query_vector=query_vector,
score_threshold=score_threshold, limit=10
limit=limit
) )
return result.points return result
except Exception as e: except Exception as e:
log.error(f"Search failed: {e}") print("Search failed: ", e)
return None return None
async def update_collection(self): async def update_collection(self):

View File

@ -1,90 +0,0 @@
import os
import requests
from urllib.parse import urlparse
from pathlib import Path
from PIL import Image
def download_image(url: str, filename: str = None) -> str:
"""
Download an image from URL and save it in data/temp/ folder.
Args:
url (str): Image URL
filename (str, optional): Custom filename. If None, extracted from URL.
Returns:
str: Full path to the downloaded image
"""
try:
# Get project root directory (where your main script is)
root_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent
# Create data/temp folder structure
temp_dir = root_dir / "data" / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
# Generate filename if not provided
if not filename:
parsed_url = urlparse(url)
filename = os.path.basename(parsed_url.path)
if not filename or "." not in filename:
# Fallback filename
ext = filename.split('.')[-1] if '.' in filename else 'jpg'
filename = f"image_{hash(url) % 100000}.{ext}"
# Ensure filename has extension
if '.' not in filename:
filename += ".jpg"
file_path = temp_dir / filename
# Download the image
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Save image
with open(file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"✅ Image downloaded: {file_path}")
return str(file_path)
except Exception as e:
print(f"❌ Failed to download image: {e}")
raise
def read_image(image_path: str) -> Image.Image:
"""
Read an image from the given path and return a PIL Image object.
Args:
image_path (str): Path to the image file
Returns:
PIL.Image.Image: Loaded image
Raises:
FileNotFoundError: If image doesn't exist
Exception: For other image loading errors
"""
try:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image not found at path: {image_path}")
# Open the image
image = Image.open(image_path)
# Convert to RGB (important for DINOv2 and most models)
if image.mode != "RGB":
image = image.convert("RGB")
print(f"✅ Image loaded successfully: {image_path} | Size: {image.size}")
return image
except FileNotFoundError as e:
print(f"❌ File not found: {e}")
raise
except Exception as e:
print(f"❌ Failed to read image: {e}")
raise

View File

@ -10,9 +10,7 @@ class CreateCollectionSerializer(BaseModel):
class QueryCollectionSerializer(BaseModel): class QueryCollectionSerializer(BaseModel):
collection_name: str collection_name: str
url: str query_vector: List[float]
score_threshold: float = 0.3 # Euclidean distance — lower = more similar. 0.3 = very tight match
limit: int = 10
class UpdateCollectionSerializer(BaseModel): class UpdateCollectionSerializer(BaseModel):
collection_name: str collection_name: str

View File

@ -2,7 +2,6 @@ from db_setup import get_qdrant_client
from typing import Annotated from typing import Annotated
from fastapi import Depends, HTTPException, APIRouter from fastapi import Depends, HTTPException, APIRouter
from qdrant_client import AsyncQdrantClient from qdrant_client import AsyncQdrantClient
from .plugins import download_image,read_image
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from .serializers import ( from .serializers import (
CreateCollectionSerializer, CreateCollectionSerializer,
@ -10,7 +9,7 @@ from .serializers import (
UpdateCollectionSerializer, UpdateCollectionSerializer,
DeleteCollectionSerializer DeleteCollectionSerializer
) )
from model_export.dino_image_matching import get_vectors,get_embedding from model_export.dino_image_matching import get_vectors
from .models import CollectionHandler from .models import CollectionHandler
import os import os
app_router = APIRouter() app_router = APIRouter()
@ -133,65 +132,16 @@ async def query_collection_endpoint(
body: QueryCollectionSerializer body: QueryCollectionSerializer
): ):
try: try:
result = []
if isinstance(body.url, str):
# Handle semicolon-separated URLs by taking the first one
target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url
log.info(f"Querying collection {body.collection_name} with URL: {target_url}")
downloaded_image_path = download_image(target_url)
query_vector = get_embedding(downloaded_image_path)
# get_embedding already returns a flat list of 768 floats
handler = CollectionHandler( handler = CollectionHandler(
collection_name=body.collection_name, collection_name=body.collection_name,
vector=query_vector, vector=body.query_vector,
vector_size=len(query_vector), vector_size=len(body.query_vector),
payload={}, payload={},
id=0, id=0
client=q
) )
search_result = await handler.search( result = await handler.search(body.query_vector)
query_vector, return JSONResponse({"results": str(result)})
score_threshold=body.score_threshold,
limit=body.limit
)
if search_result:
result = [
{"id": p.id, "score": p.score, "payload": p.payload}
for p in search_result
]
else:
result = [] # No match within threshold
elif isinstance(body.url, list):
result = []
for url in body.url:
downloaded_image_path = download_image(url)
query_vector = get_embedding(downloaded_image_path)
# get_embedding already returns a flat list of 768 floats
handler = CollectionHandler(
collection_name=body.collection_name,
vector=query_vector,
vector_size=len(query_vector),
payload={},
id=0,
client=q
)
search_result = await handler.search(
query_vector,
score_threshold=body.score_threshold,
limit=body.limit
)
if search_result:
result.append([
{"id": p.id, "score": p.score, "payload": p.payload}
for p in search_result
])
return JSONResponse({"results": result})
except Exception as e: except Exception as e:
log.error(f"Query failed: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app_router.put("/update") @app_router.put("/update")

View File

@ -46,8 +46,7 @@ def get_embedding(image_path):
# Normalize embedding (important for cosine similarity) # Normalize embedding (important for cosine similarity)
embedding = F.normalize(embedding, p=2, dim=1) embedding = F.normalize(embedding, p=2, dim=1)
# Return flat list (squeeze batch dim) return embedding.cpu()
return embedding.squeeze(0).cpu().tolist()
def get_vectors(image_input, item): def get_vectors(image_input, item):
try: try: