183 lines
6.0 KiB
Python
183 lines
6.0 KiB
Python
from db_setup import get_qdrant_client
|
|
from typing import Annotated
|
|
from fastapi import Depends, HTTPException, APIRouter
|
|
from qdrant_client import AsyncQdrantClient
|
|
from fastapi.responses import JSONResponse
|
|
from .serializers import (
|
|
CreateCollectionSerializer,
|
|
QueryCollectionSerializer,
|
|
UpdateCollectionSerializer,
|
|
DeleteCollectionSerializer
|
|
)
|
|
from model_export.dino_image_matching import get_vectors
|
|
from .models import CollectionHandler
|
|
import os
|
|
import logging

# Configure the root logger at DEBUG with a timestamped line format.
# NOTE: logging.basicConfig is a no-op when the root logger already has
# handlers, so whichever module configures logging first wins.
# NOTE(review): configuring logging from a router module is a process-wide
# side effect; consider moving it to the application entry point.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
log = logging.getLogger(__name__)

# Router on which every endpoint in this module is registered.
app_router = APIRouter()
|
|
|
|
import pandas as pd
|
|
|
|
def _listing_excel_path() -> str:
    """Return the path of the listing spreadsheet used for bulk ingestion.

    The file lives at ``model_export/listing_data.xlsx`` under the directory two
    levels above the directory containing this module.
    """
    base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    return os.path.join(base_root, "model_export", "listing_data.xlsx")


def _cell_to_str(value) -> str:
    """Return a spreadsheet cell as text, mapping missing (NaN/None) cells to "".

    ``str(float("nan"))`` is ``"nan"``; without this mapping, empty cells would be
    stored in the Qdrant payload as the literal string "nan".
    """
    return "" if pd.isna(value) else str(value)


@app_router.get("/get_vectors")
async def get_vectors_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    image_path: str = os.getenv("DATASET"),
):
    """Embed every listing in ``listing_data.xlsx`` and upsert it into "Product".

    Ensures the "Product" collection exists (768-dimensional vectors). Then, for
    every spreadsheet row with a non-empty ASIN, it computes a vector with
    ``get_vectors`` and upserts it. The point id is the DataFrame row index and
    the payload is ``{asin, title, brand, link}``. Rows whose vector comes back
    ``None`` are skipped.

    NOTE(review): this GET endpoint mutates state and takes a filesystem path
    from the query string; consider POST and validating ``image_path``.

    Args:
        q: Injected Qdrant client used for every collection operation.
        image_path: Forwarded to ``get_vectors`` as ``image_input``. Defaults to
            the ``DATASET`` environment variable as read at import time.

    Returns:
        JSONResponse with ``status``, a summary ``message`` and a ``result`` list
        holding one entry per successfully ingested row.

    Raises:
        HTTPException: status 500 wrapping any failure (missing file or column,
            vectorisation or Qdrant errors).
    """
    import asyncio

    try:
        excel_path = _listing_excel_path()
        # read_excel is blocking file I/O; keep it off the event loop.
        df = await asyncio.to_thread(pd.read_excel, excel_path)

        # Series.count() counts non-missing ASINs without building a list.
        asin_count = int(df['ASIN'].count())
        log.info(f"Generating vectors and ingesting {asin_count} ASINs into Qdrant from {excel_path}")

        # 1. Initialize/Create collection "Product".
        # DINOv2 vitb14 vector size is 768.
        init_handler = CollectionHandler(
            collection_name="Product",
            vector=[],
            vector_size=768,
            payload={},
            client=q,
        )
        await init_handler.create_collection()

        result_lst = []
        for index, row in df.iterrows():
            raw_asin = row['ASIN']
            # Check for a missing ASIN before stringifying it.
            if pd.isna(raw_asin):
                continue

            asin = str(raw_asin)
            title = _cell_to_str(row['Title'])
            brand = _cell_to_str(row['Brand'])
            link = _cell_to_str(row['Image'])

            # Model inference is synchronous; run it in a worker thread so the
            # event loop can keep serving other requests meanwhile.
            vector = await asyncio.to_thread(get_vectors, image_input=image_path, item=asin)
            if vector is None:
                log.warning(f"Skipping {asin} due to missing image/vector")
                continue

            payload = {
                "asin": asin,
                "title": title,
                "brand": brand,
                "link": link,
            }

            # 2. Ingest into Qdrant through the injected client. The point id is
            # the DataFrame index (a RangeIndex, since read_excel is called
            # without index_col).
            handler = CollectionHandler(
                collection_name="Product",
                vector=vector,
                vector_size=768,
                payload=payload,
                id=int(index),
                link=link,
                asin=asin,
                brand=brand,
                client=q,
            )
            success = await handler.upsert_point()

            if success:
                result_lst.append({
                    "item": asin,
                    "status": "ingested",
                    "payload": payload,
                })
                log.info(f"Vector ingested for {asin} (ID: {index})")
            else:
                log.error(f"Failed to ingest vector for {asin}")

        return JSONResponse({
            "status": "success",
            "message": f"Ingested {len(result_lst)} items into Product collection",
            "result": result_lst,
        })

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Bulk vector ingestion into 'Product' failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app_router.post("/create")
async def create_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: CreateCollectionSerializer = None,
):
    """Create the collection described by *body* and upsert its point.

    Args:
        q: Injected Qdrant client, passed to ``CollectionHandler``.
        body: Collection name, vector, vector size, payload and point id.

    Returns:
        JSONResponse wrapping the value returned by ``create_collection()``.

    Raises:
        HTTPException: 400 when no body is sent; 500 wrapping any other failure.
    """
    # Validate outside the try block; otherwise the generic handler below
    # would turn this 400 into a 500 with detail "400: ...".
    if body is None:
        raise HTTPException(status_code=400, detail="Collection name is required")

    try:
        log.debug("collection_name: %s", body.collection_name)

        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=body.vector_size,
            payload=body.payload,
            id=body.id,
            client=q,
        )

        # 1. Create collection
        result = await handler.create_collection()

        # 2. Automatically call upsert_point. Its result is treated as a
        # success flag elsewhere in this module; surface failures in the logs.
        upserted = await handler.upsert_point()
        if not upserted:
            log.warning("upsert_point reported failure for collection %r", body.collection_name)

        return JSONResponse(result)

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Creating collection %r failed", body.collection_name)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app_router.get("/query")
async def query_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: QueryCollectionSerializer,
):
    """Search *body.collection_name* with *body.query_vector*.

    NOTE(review): this is a GET that reads a JSON body, and many HTTP clients
    and proxies drop GET bodies; consider POST. The "results" value is
    ``str(result)`` rather than structured JSON. Both are kept for
    backward compatibility.

    Args:
        q: Injected Qdrant client, passed to ``CollectionHandler``.
        body: Collection name and query vector.

    Returns:
        JSONResponse ``{"results": str(result)}``.

    Raises:
        HTTPException: 500 wrapping any failure; HTTPExceptions pass through unchanged.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.query_vector,
            vector_size=len(body.query_vector),
            payload={},
            id=0,
            client=q,
        )
        result = await handler.search(body.query_vector)
        return JSONResponse({"results": str(result)})

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Query on collection %r failed", body.collection_name)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app_router.put("/update")
async def update_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: UpdateCollectionSerializer,
):
    """Apply *body* to its collection via ``CollectionHandler.update_collection``.

    Args:
        q: Injected Qdrant client, passed to ``CollectionHandler``.
        body: Collection name, vector, payload and point id. The vector size
            is taken from ``len(body.vector)``.

    Returns:
        JSONResponse ``{"status": "success", "result": result}``.

    Raises:
        HTTPException: 500 wrapping any failure; HTTPExceptions pass through unchanged.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=len(body.vector),
            payload=body.payload,
            id=body.id,
            client=q,
        )
        result = await handler.update_collection()
        # Kept inside the try: a non-JSON-serialisable result is reported as a 500.
        return JSONResponse({"status": "success", "result": result})

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Updating collection %r failed", body.collection_name)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app_router.delete("/delete")
async def delete_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: DeleteCollectionSerializer,
):
    """Delete the collection named in *body*.

    ``CollectionHandler`` requires vector/payload arguments; placeholder values
    are passed because only the collection name matters here.

    Args:
        q: Injected Qdrant client, passed to ``CollectionHandler``.
        body: Holds the collection name to delete.

    Returns:
        JSONResponse wrapping the value returned by ``delete_collection()``.

    Raises:
        HTTPException: 500 wrapping any failure; HTTPExceptions pass through unchanged.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=[],
            vector_size=0,
            payload={},
            id=0,
            client=q,
        )
        result = await handler.delete_collection()
        return JSONResponse(result)

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Deleting collection %r failed", body.collection_name)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|