# listing-radar/dev_backend/vector_db_router/views.py (233 lines, 8.1 KiB, Python)

from db_setup import get_qdrant_client
from typing import Annotated
from fastapi import Depends, HTTPException, APIRouter
from qdrant_client import AsyncQdrantClient
from .plugins import download_image,read_image
from fastapi.responses import JSONResponse
from .serializers import (
CreateCollectionSerializer,
QueryCollectionSerializer,
UpdateCollectionSerializer,
DeleteCollectionSerializer
)
from model_export.dino_image_matching import get_vectors,get_embedding
from .models import CollectionHandler
import os
app_router = APIRouter()
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
import pandas as pd
@app_router.get("/get_vectors")
async def get_vectors_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    image_path: str = os.getenv("DATASET")
):
    """Embed every product in the listing spreadsheet and upsert it into Qdrant.

    Reads ``model_export/listing_data.xlsx``, makes sure the ``Product``
    collection exists, and for every row that has an ASIN computes a vector
    with ``get_vectors`` and upserts it together with a payload containing
    the row's ASIN, title, brand and image link. The DataFrame row index is
    used as the point id.

    Args:
        q: Qdrant client injected by FastAPI.
        image_path: Passed to ``get_vectors`` as ``image_input``. The default
            is the value of the ``DATASET`` environment variable at import
            time.

    Returns:
        JSONResponse with ``status``, a summary ``message`` and ``result``,
        the list of successfully ingested items.

    Raises:
        HTTPException: 500 with the error text if any step fails.
    """
    collection_name = "Product"
    # DINOv2 vitb14 vector size is 768
    vector_size = 768

    try:
        # Construct path relative to this file
        base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
        # NOTE(review): pd.read_excel and get_vectors are plain synchronous
        # calls inside an async handler; if they are slow they will hold up
        # the event loop. Consider offloading them to a worker thread.
        df = pd.read_excel(excel_path)
        asin_count = int(df['ASIN'].notna().sum())
        log.info(f"Generating vectors and ingesting {asin_count} ASINs into Qdrant from {excel_path}")

        # 1. Initialize/Create collection "Product"
        init_handler = CollectionHandler(
            collection_name=collection_name,
            vector=[],
            vector_size=vector_size,
            payload={},
            client=q
        )
        await init_handler.create_collection()

        result_lst = []
        for index, row in df.iterrows():
            # Test for a missing ASIN before stringifying it: str(NaN) is the
            # non-empty string "nan", which would look like a real ASIN.
            if pd.isna(row['ASIN']):
                continue
            asin = str(row['ASIN'])
            title = str(row['Title'])
            brand = str(row['Brand'])
            link = str(row['Image'])

            vector = get_vectors(image_input=image_path, item=asin)
            if vector is None:
                log.warning(f"Skipping {asin} due to missing image/vector")
                continue

            payload = {
                "asin": asin,
                "title": title,
                "brand": brand,
                "link": link
            }

            # 2. Ingest into Qdrant using CollectionHandler
            # Use the injected client 'q' and convert index to int
            handler = CollectionHandler(
                collection_name=collection_name,
                vector=vector,
                vector_size=vector_size,
                payload=payload,
                id=int(index),
                link=link,
                asin=asin,
                brand=brand,
                client=q
            )
            success = await handler.upsert_point()
            if not success:
                log.error(f"Failed to ingest vector for {asin}")
                continue

            result_lst.append({
                "item": asin,
                "status": "ingested",
                "payload": payload
            })
            log.info(f"Vector ingested for {asin} (ID: {index})")

        return JSONResponse({
            "status": "success",
            "message": f"Ingested {len(result_lst)} items into Product collection",
            "result": result_lst
        })
    except HTTPException:
        # Already a well-formed HTTP error; do not rewrap it as a 500.
        raise
    except Exception as e:
        # Record the traceback server-side; the client only gets the message.
        log.exception("Vector ingestion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app_router.post("/create")
async def create_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: CreateCollectionSerializer = None
):
    """Create a collection and immediately upsert the point described by *body*.

    Args:
        q: Qdrant client injected by FastAPI; handed to the handler so this
            endpoint uses the same client as the ingest and query endpoints.
        body: Collection name, vector, vector size, payload and point id.

    Returns:
        JSONResponse wrapping whatever ``CollectionHandler.create_collection``
        returned.

    Raises:
        HTTPException: 400 if no body was sent; 500 with the error text if
            creating the collection or upserting the point fails.
    """
    # Validate before entering the try block: inside it, the generic
    # ``except Exception`` would catch this 400 and turn it into a 500.
    if body is None:
        raise HTTPException(status_code=400, detail="Collection name is required")

    try:
        log.info(f"collection_name: {body.collection_name}")
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=body.vector_size,
            payload=body.payload,
            id=body.id,
            client=q
        )
        # 1. Create collection
        result = await handler.create_collection()
        # 2. Automatically call upsert_point
        upserted = await handler.upsert_point()
        if not upserted:
            # The ingest endpoint treats a falsy return as failure; surface
            # it in the log here too instead of dropping it silently.
            log.error(f"Failed to upsert initial point into {body.collection_name}")
        return JSONResponse(result)
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Create collection failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def _search_by_image_url(q, collection_name, url, score_threshold, limit):
    """Download the image at *url*, embed it and search *collection_name*.

    Returns a list of ``{"id", "score", "payload"}`` dicts, one per hit, or
    an empty list when the search returned nothing.
    """
    downloaded_image_path = download_image(url)
    # Per the original author's note, get_embedding already returns a flat
    # list of 768 floats, so it is passed to the handler unchanged.
    query_vector = get_embedding(downloaded_image_path)
    handler = CollectionHandler(
        collection_name=collection_name,
        vector=query_vector,
        vector_size=len(query_vector),
        payload={},
        id=0,
        client=q
    )
    search_result = await handler.search(
        query_vector,
        score_threshold=score_threshold,
        limit=limit
    )
    if not search_result:
        return []  # No match within threshold
    return [
        {"id": p.id, "score": p.score, "payload": p.payload}
        for p in search_result
    ]


@app_router.get("/query")
async def query_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: QueryCollectionSerializer
):
    """Find stored points similar to the image(s) referenced by ``body.url``.

    ``body.url`` may be a single string or a list of strings:

    * string: if it contains ``;`` only the part before the first ``;`` is
      used (stripped of whitespace). ``results`` is a flat list of hits.
    * list: every URL is searched in turn. ``results`` is a list of hit
      lists, one per URL that produced at least one hit.

    Any other type yields an empty ``results`` list.

    Args:
        q: Qdrant client injected by FastAPI.
        body: Collection name, URL(s), score threshold and result limit.

    Returns:
        JSONResponse of the form ``{"results": [...]}``.

    Raises:
        HTTPException: 500 with the error text if download, embedding or
            search fails.
    """
    try:
        result = []
        if isinstance(body.url, str):
            # Handle semicolon-separated URLs by taking the first one
            target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url
            log.info(f"Querying collection {body.collection_name} with URL: {target_url}")
            result = await _search_by_image_url(
                q, body.collection_name, target_url, body.score_threshold, body.limit
            )
        elif isinstance(body.url, list):
            for url in body.url:
                hits = await _search_by_image_url(
                    q, body.collection_name, url, body.score_threshold, body.limit
                )
                # NOTE(review): URLs without a match are skipped, so the
                # output is not positionally aligned with the input list.
                # Kept as-is to avoid changing the response shape.
                if hits:
                    result.append(hits)
        return JSONResponse({"results": result})
    except HTTPException:
        raise
    except Exception as e:
        # log.exception keeps the traceback, which log.error dropped.
        log.exception(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app_router.put("/update")
async def update_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: UpdateCollectionSerializer
):
    """Update a collection using the vector, payload and id given in *body*.

    The vector size handed to the handler is derived from ``len(body.vector)``.

    Args:
        q: Qdrant client injected by FastAPI; handed to the handler so this
            endpoint uses the same client as the ingest and query endpoints.
        body: Collection name, vector, payload and point id.

    Returns:
        JSONResponse ``{"status": "success", "result": ...}`` where
        ``result`` is whatever ``CollectionHandler.update_collection``
        returned.

    Raises:
        HTTPException: 500 with the error text if the update fails.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=len(body.vector),
            payload=body.payload,
            id=body.id,
            client=q
        )
        result = await handler.update_collection()
        return JSONResponse({"status": "success", "result": result})
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Update collection failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app_router.delete("/delete")
async def delete_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: DeleteCollectionSerializer
):
    """Delete the collection named in *body*.

    Only ``body.collection_name`` is used; the handler's vector, payload and
    id arguments are filled with empty placeholders.

    Args:
        q: Qdrant client injected by FastAPI; handed to the handler so this
            endpoint uses the same client as the ingest and query endpoints.
        body: Carries the name of the collection to delete.

    Returns:
        JSONResponse wrapping whatever ``CollectionHandler.delete_collection``
        returned.

    Raises:
        HTTPException: 500 with the error text if the deletion fails.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=[],
            vector_size=0,
            payload={},
            id=0,
            client=q
        )
        result = await handler.delete_collection()
        return JSONResponse(result)
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Delete collection failed")
        raise HTTPException(status_code=500, detail=str(e)) from e