Compare commits

..

No commits in common. "6ed79621b06617e82194c2b73eaeeba2b98bc3dd" and "56f52a4edcf6913ab0558a1e8afa572cbf30a57d" have entirely different histories.

6 changed files with 311 additions and 458 deletions

4
.gitignore vendored
View File

@@ -5,5 +5,5 @@
**pycache** **pycache**
*agent/** *agent/**
**downloaded_images** **downloaded_images**
**model_export** __pycache__/
**cpython** *.pyc

View File

@@ -1,108 +1,104 @@
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from typing import Dict, Any from typing import Dict, Any
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') class CollectionHandler:
log = logging.getLogger(__name__) def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
id: int=None,
class CollectionHandler: link: str=None,
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, asin: str=None,
id: int=None, category: str=None,
link: str=None, brand: str=None,
asin: str=None, client: AsyncQdrantClient=None
category: str=None, ):
brand: str=None, self.collection_name = collection_name
client: AsyncQdrantClient=None self.vector = vector
): self.id = id
self.collection_name = collection_name self.vector_size = vector_size
self.vector = vector self.payload = payload
self.id = id self.link = link
self.vector_size = vector_size self.asin = asin
self.payload = payload self.category = category
self.link = link self.brand = brand
self.asin = asin self.client = client if client else AsyncQdrantClient("localhost", port=6333)
self.category = category
self.brand = brand async def create_collection(self):
self.client = client if client else AsyncQdrantClient("localhost", port=6333) try:
if await self.client.collection_exists(self.collection_name):
async def create_collection(self): return {"message": "Collection already exists"}
try:
if await self.client.collection_exists(self.collection_name): await self.client.create_collection(
return {"message": "Collection already exists"} collection_name=self.collection_name,
vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
await self.client.create_collection( optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
collection_name=self.collection_name, )
vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000) # Creating payload indexes as per project logic
)
await self.client.create_payload_index(
# Creating payload indexes as per project logic collection_name=self.collection_name,
field_name="link",
await self.client.create_payload_index( field_schema=models.PayloadSchemaType.KEYWORD
collection_name=self.collection_name, )
field_name="link", await self.client.create_payload_index(
field_schema=models.PayloadSchemaType.KEYWORD collection_name=self.collection_name,
) field_name="title",
await self.client.create_payload_index( field_schema=models.PayloadSchemaType.KEYWORD
collection_name=self.collection_name, )
field_name="title",
field_schema=models.PayloadSchemaType.KEYWORD await self.client.create_payload_index(
) collection_name=self.collection_name,
field_name="brand",
await self.client.create_payload_index( field_schema=models.PayloadSchemaType.KEYWORD
collection_name=self.collection_name, )
field_name="brand", await self.client.create_payload_index(
field_schema=models.PayloadSchemaType.KEYWORD collection_name=self.collection_name,
) field_name="asin",
await self.client.create_payload_index( field_schema=models.PayloadSchemaType.KEYWORD
collection_name=self.collection_name, )
field_name="asin",
field_schema=models.PayloadSchemaType.KEYWORD
) return {"message": f"Collection {self.collection_name} created successfully"}
except Exception as e:
return {"message": str(e)}
return {"message": f"Collection {self.collection_name} created successfully"}
except Exception as e: async def insertion(self):
return {"message": str(e)} try:
await self.client.upsert(
async def insertion(self): collection_name=self.collection_name,
try: points=[
await self.client.upsert( PointStruct(id=self.id, vector=self.vector, payload=self.payload)
collection_name=self.collection_name, ]
points=[ )
PointStruct(id=self.id, vector=self.vector, payload=self.payload) return True
] except Exception as e:
) # Note: In a real app we'd use a logger here
return True print(f"Insertion failed for ID {self.id}: {e}")
except Exception as e: return False
# Note: In a real app we'd use a logger here
print(f"Insertion failed for ID {self.id}: {e}") async def upsert_point(self):
return False return await self.insertion()
async def upsert_point(self): async def search(self, query_vector):
return await self.insertion() try:
result = await self.client.search(
async def search(self, query_vector, score_threshold: float = 0.3, limit: int = 10): collection_name=self.collection_name,
try: query_vector=query_vector,
result = await self.client.query_points( limit=10
collection_name=self.collection_name, )
query=query_vector, return result
score_threshold=score_threshold, except Exception as e:
limit=limit print("Search failed: ", e)
) return None
return result.points
except Exception as e: async def update_collection(self):
log.error(f"Search failed: {e}") """Update is implemented as an upsert of the point data."""
return None return await self.upsert_point()
async def update_collection(self): async def delete_collection(self):
"""Update is implemented as an upsert of the point data.""" try:
return await self.upsert_point() await self.client.delete_collection(collection_name=self.collection_name)
return {"message": f"Collection {self.collection_name} deleted successfully"}
async def delete_collection(self): except Exception as e:
try: return {"message": str(e)}
await self.client.delete_collection(collection_name=self.collection_name)
return {"message": f"Collection {self.collection_name} deleted successfully"}
except Exception as e:
return {"message": str(e)}

View File

@@ -1,90 +0,0 @@
import hashlib
import os
from pathlib import Path
from urllib.parse import urlparse

import requests
from PIL import Image
def download_image(url: str, filename: str = None) -> str:
    """
    Download an image from URL and save it in data/temp/ folder.

    Args:
        url (str): Image URL
        filename (str, optional): Custom filename. If None, extracted from URL.

    Returns:
        str: Full path to the downloaded image

    Raises:
        requests.RequestException: If the HTTP request or status check fails.
    """
    try:
        # Get project root directory (one level above this module)
        root_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent

        # Create data/temp folder structure
        temp_dir = root_dir / "data" / "temp"
        temp_dir.mkdir(parents=True, exist_ok=True)

        # Generate filename if not provided
        if not filename:
            parsed_url = urlparse(url)
            filename = os.path.basename(parsed_url.path)
            if not filename or "." not in filename:
                # Stable fallback name: built-in hash() is salted per process,
                # so the same URL would map to a different file on every run.
                # A digest of the URL is deterministic across runs.
                digest = hashlib.md5(url.encode("utf-8")).hexdigest()[:10]
                filename = f"image_{digest}.jpg"

        # Ensure filename has extension
        if '.' not in filename:
            filename += ".jpg"

        file_path = temp_dir / filename

        # Download the image; the context manager guarantees the streamed
        # connection is released even if writing the file fails.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(file_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        print(f"✅ Image downloaded: {file_path}")
        return str(file_path)
    except Exception as e:
        print(f"❌ Failed to download image: {e}")
        raise
def read_image(image_path: str) -> Image.Image:
    """
    Load the image at *image_path* and return it as an RGB PIL image.

    Args:
        image_path (str): Path to the image file

    Returns:
        PIL.Image.Image: Loaded image, converted to RGB when necessary

    Raises:
        FileNotFoundError: If no file exists at *image_path*
        Exception: Propagated from PIL for any other loading failure
    """
    try:
        # Fail fast with an explicit error when the path is missing.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at path: {image_path}")

        image = Image.open(image_path)
        # DINOv2 and most vision models expect 3-channel RGB input.
        image = image if image.mode == "RGB" else image.convert("RGB")

        print(f"✅ Image loaded successfully: {image_path} | Size: {image.size}")
        return image
    except FileNotFoundError as e:
        print(f"❌ File not found: {e}")
        raise
    except Exception as e:
        print(f"❌ Failed to read image: {e}")
        raise

View File

@@ -1,24 +1,22 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import Dict, List, Any from typing import Dict, List, Any
class CreateCollectionSerializer(BaseModel): class CreateCollectionSerializer(BaseModel):
collection_name: str collection_name: str
vector: List[float] vector: List[float]
vector_size: int vector_size: int
payload: Dict[str, Any] payload: Dict[str, Any]
id: int id: int
class QueryCollectionSerializer(BaseModel): class QueryCollectionSerializer(BaseModel):
collection_name: str collection_name: str
url: str query_vector: List[float]
score_threshold: float = 0.3 # Euclidean distance — lower = more similar. 0.3 = very tight match
limit: int = 10 class UpdateCollectionSerializer(BaseModel):
collection_name: str
class UpdateCollectionSerializer(BaseModel): vector: List[float]
collection_name: str payload: Dict[str, Any]
vector: List[float] id: int
payload: Dict[str, Any]
id: int class DeleteCollectionSerializer(BaseModel):
collection_name: str
class DeleteCollectionSerializer(BaseModel):
collection_name: str

View File

@@ -1,233 +1,183 @@
from db_setup import get_qdrant_client from db_setup import get_qdrant_client
from typing import Annotated from typing import Annotated
from fastapi import Depends, HTTPException, APIRouter from fastapi import Depends, HTTPException, APIRouter
from qdrant_client import AsyncQdrantClient from qdrant_client import AsyncQdrantClient
from .plugins import download_image,read_image from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse from .serializers import (
from .serializers import ( CreateCollectionSerializer,
CreateCollectionSerializer, QueryCollectionSerializer,
QueryCollectionSerializer, UpdateCollectionSerializer,
UpdateCollectionSerializer, DeleteCollectionSerializer
DeleteCollectionSerializer )
) from model_export.dino_image_matching import get_vectors
from model_export.dino_image_matching import get_vectors,get_embedding from .models import CollectionHandler
from .models import CollectionHandler import os
import os app_router = APIRouter()
app_router = APIRouter() import logging
import logging logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
import pandas as pd
import pandas as pd
@app_router.get("/get_vectors")
@app_router.get("/get_vectors") async def get_vectors_endpoint(
async def get_vectors_endpoint( q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], image_path:str=os.getenv("DATASET")
image_path:str=os.getenv("DATASET") ):
): try:
try: # Construct path relative to this file
# Construct path relative to this file base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx") df = pd.read_excel(excel_path)
df = pd.read_excel(excel_path) asin_list = df['ASIN'].dropna().astype(str).tolist()
asin_list = df['ASIN'].dropna().astype(str).tolist() log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")
log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")
# 1. Initialize/Create collection "Product"
# 1. Initialize/Create collection "Product" # DINOv2 vitb14 vector size is 768
# DINOv2 vitb14 vector size is 768 init_handler = CollectionHandler(
init_handler = CollectionHandler( collection_name="Product",
collection_name="Product", vector=[],
vector=[], vector_size=768,
vector_size=768, payload={},
payload={}, client=q
client=q )
) await init_handler.create_collection()
await init_handler.create_collection()
result_lst = []
result_lst = [] for index, row in df.iterrows():
for index, row in df.iterrows(): asin = str(row['ASIN'])
asin = str(row['ASIN']) if pd.isna(row['ASIN']):
if pd.isna(row['ASIN']): continue
continue
title = str(row['Title'])
title = str(row['Title']) brand = str(row['Brand'])
brand = str(row['Brand']) link = str(row['Image'])
link = str(row['Image'])
# Call get_vectors
# Call get_vectors vector = get_vectors(image_input=image_path, item=asin)
vector = get_vectors(image_input=image_path, item=asin)
if vector is None:
if vector is None: log.warning(f"Skipping {asin} due to missing image/vector")
log.warning(f"Skipping {asin} due to missing image/vector") continue
continue
payload = {
payload = { "asin": asin,
"asin": asin, "title": title,
"title": title, "brand": brand,
"brand": brand, "link": link
"link": link }
}
# 2. Ingest into Qdrant using CollectionHandler
# 2. Ingest into Qdrant using CollectionHandler # Use the injected client 'q' and convert index to int
# Use the injected client 'q' and convert index to int handler = CollectionHandler(
handler = CollectionHandler( collection_name="Product",
collection_name="Product", vector=vector,
vector=vector, vector_size=768,
vector_size=768, payload=payload,
payload=payload, id=int(index),
id=int(index), link=link,
link=link, asin=asin,
asin=asin, brand=brand,
brand=brand, client=q
client=q )
) success = await handler.upsert_point()
success = await handler.upsert_point()
if success:
if success: result_lst.append({
result_lst.append({ "item": asin,
"item": asin, "status": "ingested",
"status": "ingested", "payload": payload
"payload": payload })
}) log.info(f"Vector ingested for {asin} (ID: {index})")
log.info(f"Vector ingested for {asin} (ID: {index})") else:
else: log.error(f"Failed to ingest vector for {asin}")
log.error(f"Failed to ingest vector for {asin}")
return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst})
return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst}) except Exception as e:
except Exception as e: raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))
@app_router.post("/create")
@app_router.post("/create") async def create_collection_endpoint(
async def create_collection_endpoint( q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], body: CreateCollectionSerializer = None
body: CreateCollectionSerializer = None ):
): try:
try: if body is None:
if body is None: raise HTTPException(status_code=400, detail="Collection name is required")
raise HTTPException(status_code=400, detail="Collection name is required")
print("collection_name: ", body.collection_name)
print("collection_name: ", body.collection_name)
handler = CollectionHandler(
handler = CollectionHandler( collection_name=body.collection_name,
collection_name=body.collection_name, vector=body.vector,
vector=body.vector, vector_size=body.vector_size,
vector_size=body.vector_size, payload=body.payload,
payload=body.payload, id=body.id
id=body.id )
)
# 1. Create collection
# 1. Create collection result = await handler.create_collection()
result = await handler.create_collection()
# 2. Automatically call upsert_point
# 2. Automatically call upsert_point await handler.upsert_point()
await handler.upsert_point()
return JSONResponse(result)
return JSONResponse(result)
except Exception as e:
except Exception as e: raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))
@app_router.get("/query")
@app_router.get("/query") async def query_collection_endpoint(
async def query_collection_endpoint( q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)], body: QueryCollectionSerializer
body: QueryCollectionSerializer ):
): try:
try: handler = CollectionHandler(
result = [] collection_name=body.collection_name,
if isinstance(body.url, str): vector=body.query_vector,
# Handle semicolon-separated URLs by taking the first one vector_size=len(body.query_vector),
target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url payload={},
log.info(f"Querying collection {body.collection_name} with URL: {target_url}") id=0
downloaded_image_path = download_image(target_url) )
query_vector = get_embedding(downloaded_image_path) result = await handler.search(body.query_vector)
# get_embedding already returns a flat list of 768 floats return JSONResponse({"results": str(result)})
except Exception as e:
handler = CollectionHandler( raise HTTPException(status_code=500, detail=str(e))
collection_name=body.collection_name,
vector=query_vector, @app_router.put("/update")
vector_size=len(query_vector), async def update_collection_endpoint(
payload={}, q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
id=0, body: UpdateCollectionSerializer
client=q ):
) try:
search_result = await handler.search( handler = CollectionHandler(
query_vector, collection_name=body.collection_name,
score_threshold=body.score_threshold, vector=body.vector,
limit=body.limit vector_size=len(body.vector),
) payload=body.payload,
if search_result: id=body.id
result = [ )
{"id": p.id, "score": p.score, "payload": p.payload} result = await handler.update_collection()
for p in search_result return JSONResponse({"status": "success", "result": result})
] except Exception as e:
else: raise HTTPException(status_code=500, detail=str(e))
result = [] # No match within threshold
@app_router.delete("/delete")
elif isinstance(body.url, list): async def delete_collection_endpoint(
result = [] q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
for url in body.url: body: DeleteCollectionSerializer
downloaded_image_path = download_image(url) ):
query_vector = get_embedding(downloaded_image_path) try:
# get_embedding already returns a flat list of 768 floats handler = CollectionHandler(
collection_name=body.collection_name,
handler = CollectionHandler( vector=[],
collection_name=body.collection_name, vector_size=0,
vector=query_vector, payload={},
vector_size=len(query_vector), id=0
payload={}, )
id=0, result = await handler.delete_collection()
client=q return JSONResponse(result)
) except Exception as e:
search_result = await handler.search( raise HTTPException(status_code=500, detail=str(e))
query_vector,
score_threshold=body.score_threshold,
limit=body.limit
)
if search_result:
result.append([
{"id": p.id, "score": p.score, "payload": p.payload}
for p in search_result
])
return JSONResponse({"results": result})
except Exception as e:
log.error(f"Query failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app_router.put("/update")
async def update_collection_endpoint(
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
body: UpdateCollectionSerializer
):
try:
handler = CollectionHandler(
collection_name=body.collection_name,
vector=body.vector,
vector_size=len(body.vector),
payload=body.payload,
id=body.id
)
result = await handler.update_collection()
return JSONResponse({"status": "success", "result": result})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app_router.delete("/delete")
async def delete_collection_endpoint(
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
body: DeleteCollectionSerializer
):
try:
handler = CollectionHandler(
collection_name=body.collection_name,
vector=[],
vector_size=0,
payload={},
id=0
)
result = await handler.delete_collection()
return JSONResponse(result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -46,8 +46,7 @@ def get_embedding(image_path):
# Normalize embedding (important for cosine similarity) # Normalize embedding (important for cosine similarity)
embedding = F.normalize(embedding, p=2, dim=1) embedding = F.normalize(embedding, p=2, dim=1)
# Return flat list (squeeze batch dim) return embedding.cpu()
return embedding.squeeze(0).cpu().tolist()
def get_vectors(image_input, item): def get_vectors(image_input, item):
try: try: