"""FastAPI routes for managing Qdrant collections of image embeddings.

Endpoints cover bulk ingestion of listings from
``model_export/listing_data.xlsx``, collection creation, similarity search by
image URL, point update and collection deletion. Every route receives an
``AsyncQdrantClient`` through the ``get_qdrant_client`` dependency and passes
it to ``CollectionHandler``.
"""
import logging
import os
from typing import Annotated, Optional

import pandas as pd
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from qdrant_client import AsyncQdrantClient

from db_setup import get_qdrant_client
from model_export.dino_image_matching import get_embedding, get_vectors

from .models import CollectionHandler
from .plugins import download_image, read_image
from .serializers import (
    CreateCollectionSerializer,
    DeleteCollectionSerializer,
    QueryCollectionSerializer,
    UpdateCollectionSerializer,
)

app_router = APIRouter()

# NOTE(review): configuring the root logger at import time affects the whole
# process; consider moving this to the application entry point.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

# Collection that /get_vectors ingests into.
_PRODUCT_COLLECTION = "Product"
# Embedding width used for that collection (DINOv2 vitb14, per the original comment).
_PRODUCT_VECTOR_SIZE = 768


@app_router.get("/get_vectors")
async def get_vectors_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    image_path: Optional[str] = None,
):
    """Embed every listing in ``listing_data.xlsx`` and upsert it into "Product".

    Args:
        q: Injected Qdrant client.
        image_path: Passed to ``get_vectors`` as ``image_input``. When omitted,
            the ``DATASET`` environment variable is used (read per request).

    Returns:
        JSONResponse with a summary message and one entry per point that
        ``upsert_point`` reported as successful.

    Raises:
        HTTPException: 500 if reading the spreadsheet, creating the
            collection or ingesting fails.
    """
    # Resolve the default at request time; a default of os.getenv(...) in the
    # signature would be frozen to whatever DATASET was when the module loaded.
    if image_path is None:
        image_path = os.getenv("DATASET")
    try:
        # Three dirname() calls above this file's absolute path.
        base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
        df = pd.read_excel(excel_path)
        asin_count = len(df['ASIN'].dropna())
        log.info(f"Generating vectors and ingesting {asin_count} ASINs into Qdrant from {excel_path}")

        # 1. Initialize/create the collection before any point is upserted.
        init_handler = CollectionHandler(
            collection_name=_PRODUCT_COLLECTION,
            vector=[],
            vector_size=_PRODUCT_VECTOR_SIZE,
            payload={},
            client=q,
        )
        await init_handler.create_collection()

        result_lst = []
        for index, row in df.iterrows():
            raw_asin = row['ASIN']
            # Check for a missing ASIN before stringifying (str(nan) == "nan").
            if pd.isna(raw_asin):
                continue
            asin = str(raw_asin)
            title = str(row['Title'])
            brand = str(row['Brand'])
            link = str(row['Image'])

            # NOTE(review): get_vectors is called without await, so each
            # embedding runs on the event-loop thread and blocks other requests.
            vector = get_vectors(image_input=image_path, item=asin)
            if vector is None:
                log.warning(f"Skipping {asin} due to missing image/vector")
                continue

            payload = {
                "asin": asin,
                "title": title,
                "brand": brand,
                "link": link,
            }

            # 2. Ingest using the injected client; the DataFrame index label
            # becomes the point id.
            handler = CollectionHandler(
                collection_name=_PRODUCT_COLLECTION,
                vector=vector,
                vector_size=_PRODUCT_VECTOR_SIZE,
                payload=payload,
                id=int(index),
                link=link,
                asin=asin,
                brand=brand,
                client=q,
            )
            success = await handler.upsert_point()
            if success:
                result_lst.append({
                    "item": asin,
                    "status": "ingested",
                    "payload": payload,
                })
                log.info(f"Vector ingested for {asin} (ID: {index})")
            else:
                log.error(f"Failed to ingest vector for {asin}")

        return JSONResponse({
            "status": "success",
            "message": f"Ingested {len(result_lst)} items into Product collection",
            "result": result_lst,
        })
    except HTTPException:
        raise
    except Exception as e:
        log.exception(f"Vector ingestion failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app_router.post("/create")
async def create_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: CreateCollectionSerializer = None,
):
    """Create a collection and upsert the point described by *body* into it.

    Raises:
        HTTPException: 400 when no body is supplied; 500 on any other failure.
    """
    # Validate outside the try block so the 400 is not re-wrapped as a 500.
    if body is None:
        raise HTTPException(status_code=400, detail="Collection name is required")
    try:
        log.debug("collection_name: %s", body.collection_name)
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=body.vector_size,
            payload=body.payload,
            id=body.id,
            client=q,
        )
        # 1. Create collection.
        result = await handler.create_collection()
        # 2. Automatically upsert the supplied point.
        await handler.upsert_point()
        return JSONResponse(result)
    except HTTPException:
        raise
    except Exception as e:
        log.exception(f"Create collection failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


async def _search_image_url(client, collection_name, url, score_threshold, limit):
    """Download *url*, embed it, and return its matches as JSON-ready dicts.

    Returns an empty list when the search yields nothing within the threshold.
    """
    downloaded_image_path = download_image(url)
    # The query vector's length is taken from whatever get_embedding returns.
    query_vector = get_embedding(downloaded_image_path)
    handler = CollectionHandler(
        collection_name=collection_name,
        vector=query_vector,
        vector_size=len(query_vector),
        payload={},
        id=0,
        client=client,
    )
    search_result = await handler.search(
        query_vector,
        score_threshold=score_threshold,
        limit=limit,
    )
    return [
        {"id": p.id, "score": p.score, "payload": p.payload}
        for p in (search_result or [])
    ]


# NOTE(review): this GET route reads a request body, which many HTTP clients
# and proxies drop; consider a POST alias.
@app_router.get("/query")
async def query_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: QueryCollectionSerializer,
):
    """Search *body.collection_name* for images similar to *body.url*.

    ``body.url`` may be:

    * a string, in which case only its first ``;``-separated URL is used and
      ``results`` is a flat list of matches; or
    * a list of URLs, in which case ``results`` holds one list of matches per
      input URL, in input order (empty when a URL has no match).

    Raises:
        HTTPException: 500 on download, embedding or search failure.
    """
    try:
        result = []
        if isinstance(body.url, str):
            # Handle semicolon-separated URLs by taking the first one.
            target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url
            log.info(f"Querying collection {body.collection_name} with URL: {target_url}")
            result = await _search_image_url(
                q, body.collection_name, target_url, body.score_threshold, body.limit
            )
        elif isinstance(body.url, list):
            # One entry per input URL so results stay positionally aligned.
            result = [
                await _search_image_url(
                    q, body.collection_name, url, body.score_threshold, body.limit
                )
                for url in body.url
            ]
        return JSONResponse({"results": result})
    except HTTPException:
        raise
    except Exception as e:
        log.exception(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app_router.put("/update")
async def update_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: UpdateCollectionSerializer,
):
    """Update a point via ``CollectionHandler.update_collection``.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=len(body.vector),
            payload=body.payload,
            id=body.id,
            client=q,
        )
        result = await handler.update_collection()
        return JSONResponse({"status": "success", "result": result})
    except HTTPException:
        raise
    except Exception as e:
        log.exception(f"Update failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app_router.delete("/delete")
async def delete_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: DeleteCollectionSerializer,
):
    """Delete *body.collection_name* via ``CollectionHandler.delete_collection``.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        # Only the collection name matters here; the other fields are placeholders.
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=[],
            vector_size=0,
            payload={},
            id=0,
            client=q,
        )
        result = await handler.delete_collection()
        return JSONResponse(result)
    except HTTPException:
        raise
    except Exception as e:
        log.exception(f"Delete failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e