Compare commits
2 Commits
26b0b7f1a9
...
56f52a4edc
| Author | SHA1 | Date |
|---|---|---|
|
|
56f52a4edc | |
|
|
65cb6680eb |
|
|
@ -4,4 +4,6 @@
|
||||||
*pyc**
|
*pyc**
|
||||||
**pycache**
|
**pycache**
|
||||||
*agent/**
|
*agent/**
|
||||||
**downloaded_images**
|
**downloaded_images**
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,104 +1,104 @@
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
class CollectionHandler:
|
class CollectionHandler:
|
||||||
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
|
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
|
||||||
id: int=None,
|
id: int=None,
|
||||||
link: str=None,
|
link: str=None,
|
||||||
asin: str=None,
|
asin: str=None,
|
||||||
category: str=None,
|
category: str=None,
|
||||||
brand: str=None,
|
brand: str=None,
|
||||||
client: AsyncQdrantClient=None
|
client: AsyncQdrantClient=None
|
||||||
):
|
):
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.vector = vector
|
self.vector = vector
|
||||||
self.id = id
|
self.id = id
|
||||||
self.vector_size = vector_size
|
self.vector_size = vector_size
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.link = link
|
self.link = link
|
||||||
self.asin = asin
|
self.asin = asin
|
||||||
self.category = category
|
self.category = category
|
||||||
self.brand = brand
|
self.brand = brand
|
||||||
self.client = client if client else AsyncQdrantClient("localhost", port=6333)
|
self.client = client if client else AsyncQdrantClient("localhost", port=6333)
|
||||||
|
|
||||||
async def create_collection(self):
|
async def create_collection(self):
|
||||||
try:
|
try:
|
||||||
if await self.client.collection_exists(self.collection_name):
|
if await self.client.collection_exists(self.collection_name):
|
||||||
return {"message": "Collection already exists"}
|
return {"message": "Collection already exists"}
|
||||||
|
|
||||||
await self.client.create_collection(
|
await self.client.create_collection(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
|
vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
|
||||||
optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
|
optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Creating payload indexes as per project logic
|
# Creating payload indexes as per project logic
|
||||||
|
|
||||||
await self.client.create_payload_index(
|
await self.client.create_payload_index(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
field_name="link",
|
field_name="link",
|
||||||
field_schema=models.PayloadSchemaType.KEYWORD
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
)
|
)
|
||||||
await self.client.create_payload_index(
|
await self.client.create_payload_index(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
field_name="title",
|
field_name="title",
|
||||||
field_schema=models.PayloadSchemaType.KEYWORD
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.client.create_payload_index(
|
await self.client.create_payload_index(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
field_name="brand",
|
field_name="brand",
|
||||||
field_schema=models.PayloadSchemaType.KEYWORD
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
)
|
)
|
||||||
await self.client.create_payload_index(
|
await self.client.create_payload_index(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
field_name="asin",
|
field_name="asin",
|
||||||
field_schema=models.PayloadSchemaType.KEYWORD
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
return {"message": f"Collection {self.collection_name} created successfully"}
|
return {"message": f"Collection {self.collection_name} created successfully"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"message": str(e)}
|
return {"message": str(e)}
|
||||||
|
|
||||||
async def insertion(self):
|
async def insertion(self):
|
||||||
try:
|
try:
|
||||||
await self.client.upsert(
|
await self.client.upsert(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
points=[
|
points=[
|
||||||
PointStruct(id=self.id, vector=self.vector, payload=self.payload)
|
PointStruct(id=self.id, vector=self.vector, payload=self.payload)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Note: In a real app we'd use a logger here
|
# Note: In a real app we'd use a logger here
|
||||||
print(f"Insertion failed for ID {self.id}: {e}")
|
print(f"Insertion failed for ID {self.id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def upsert_point(self):
|
async def upsert_point(self):
|
||||||
return await self.insertion()
|
return await self.insertion()
|
||||||
|
|
||||||
async def search(self, query_vector):
|
async def search(self, query_vector):
|
||||||
try:
|
try:
|
||||||
result = await self.client.search(
|
result = await self.client.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Search failed: ", e)
|
print("Search failed: ", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def update_collection(self):
|
async def update_collection(self):
|
||||||
"""Update is implemented as an upsert of the point data."""
|
"""Update is implemented as an upsert of the point data."""
|
||||||
return await self.upsert_point()
|
return await self.upsert_point()
|
||||||
|
|
||||||
async def delete_collection(self):
|
async def delete_collection(self):
|
||||||
try:
|
try:
|
||||||
await self.client.delete_collection(collection_name=self.collection_name)
|
await self.client.delete_collection(collection_name=self.collection_name)
|
||||||
return {"message": f"Collection {self.collection_name} deleted successfully"}
|
return {"message": f"Collection {self.collection_name} deleted successfully"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"message": str(e)}
|
return {"message": str(e)}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,22 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
class CreateCollectionSerializer(BaseModel):
|
class CreateCollectionSerializer(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
vector: List[float]
|
vector: List[float]
|
||||||
vector_size: int
|
vector_size: int
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
class QueryCollectionSerializer(BaseModel):
|
class QueryCollectionSerializer(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
query_vector: List[float]
|
query_vector: List[float]
|
||||||
|
|
||||||
class UpdateCollectionSerializer(BaseModel):
|
class UpdateCollectionSerializer(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
vector: List[float]
|
vector: List[float]
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
class DeleteCollectionSerializer(BaseModel):
|
class DeleteCollectionSerializer(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
|
|
|
||||||
|
|
@ -1,183 +1,183 @@
|
||||||
from db_setup import get_qdrant_client
|
from db_setup import get_qdrant_client
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import Depends, HTTPException, APIRouter
|
from fastapi import Depends, HTTPException, APIRouter
|
||||||
from qdrant_client import AsyncQdrantClient
|
from qdrant_client import AsyncQdrantClient
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from .serializers import (
|
from .serializers import (
|
||||||
CreateCollectionSerializer,
|
CreateCollectionSerializer,
|
||||||
QueryCollectionSerializer,
|
QueryCollectionSerializer,
|
||||||
UpdateCollectionSerializer,
|
UpdateCollectionSerializer,
|
||||||
DeleteCollectionSerializer
|
DeleteCollectionSerializer
|
||||||
)
|
)
|
||||||
from model_export.dino_image_matching import get_vectors
|
from model_export.dino_image_matching import get_vectors
|
||||||
from .models import CollectionHandler
|
from .models import CollectionHandler
|
||||||
import os
|
import os
|
||||||
app_router = APIRouter()
|
app_router = APIRouter()
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@app_router.get("/get_vectors")
|
@app_router.get("/get_vectors")
|
||||||
async def get_vectors_endpoint(
|
async def get_vectors_endpoint(
|
||||||
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
||||||
image_path:str=os.getenv("DATASET")
|
image_path:str=os.getenv("DATASET")
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Construct path relative to this file
|
# Construct path relative to this file
|
||||||
base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
|
excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
|
||||||
df = pd.read_excel(excel_path)
|
df = pd.read_excel(excel_path)
|
||||||
asin_list = df['ASIN'].dropna().astype(str).tolist()
|
asin_list = df['ASIN'].dropna().astype(str).tolist()
|
||||||
log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")
|
log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")
|
||||||
|
|
||||||
# 1. Initialize/Create collection "Product"
|
# 1. Initialize/Create collection "Product"
|
||||||
# DINOv2 vitb14 vector size is 768
|
# DINOv2 vitb14 vector size is 768
|
||||||
init_handler = CollectionHandler(
|
init_handler = CollectionHandler(
|
||||||
collection_name="Product",
|
collection_name="Product",
|
||||||
vector=[],
|
vector=[],
|
||||||
vector_size=768,
|
vector_size=768,
|
||||||
payload={},
|
payload={},
|
||||||
client=q
|
client=q
|
||||||
)
|
)
|
||||||
await init_handler.create_collection()
|
await init_handler.create_collection()
|
||||||
|
|
||||||
result_lst = []
|
result_lst = []
|
||||||
for index, row in df.iterrows():
|
for index, row in df.iterrows():
|
||||||
asin = str(row['ASIN'])
|
asin = str(row['ASIN'])
|
||||||
if pd.isna(row['ASIN']):
|
if pd.isna(row['ASIN']):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
title = str(row['Title'])
|
title = str(row['Title'])
|
||||||
brand = str(row['Brand'])
|
brand = str(row['Brand'])
|
||||||
link = str(row['Image'])
|
link = str(row['Image'])
|
||||||
|
|
||||||
# Call get_vectors
|
# Call get_vectors
|
||||||
vector = get_vectors(image_input=image_path, item=asin)
|
vector = get_vectors(image_input=image_path, item=asin)
|
||||||
|
|
||||||
if vector is None:
|
if vector is None:
|
||||||
log.warning(f"Skipping {asin} due to missing image/vector")
|
log.warning(f"Skipping {asin} due to missing image/vector")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"asin": asin,
|
"asin": asin,
|
||||||
"title": title,
|
"title": title,
|
||||||
"brand": brand,
|
"brand": brand,
|
||||||
"link": link
|
"link": link
|
||||||
}
|
}
|
||||||
|
|
||||||
# 2. Ingest into Qdrant using CollectionHandler
|
# 2. Ingest into Qdrant using CollectionHandler
|
||||||
# Use the injected client 'q' and convert index to int
|
# Use the injected client 'q' and convert index to int
|
||||||
handler = CollectionHandler(
|
handler = CollectionHandler(
|
||||||
collection_name="Product",
|
collection_name="Product",
|
||||||
vector=vector,
|
vector=vector,
|
||||||
vector_size=768,
|
vector_size=768,
|
||||||
payload=payload,
|
payload=payload,
|
||||||
id=int(index),
|
id=int(index),
|
||||||
link=link,
|
link=link,
|
||||||
asin=asin,
|
asin=asin,
|
||||||
brand=brand,
|
brand=brand,
|
||||||
client=q
|
client=q
|
||||||
)
|
)
|
||||||
success = await handler.upsert_point()
|
success = await handler.upsert_point()
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
result_lst.append({
|
result_lst.append({
|
||||||
"item": asin,
|
"item": asin,
|
||||||
"status": "ingested",
|
"status": "ingested",
|
||||||
"payload": payload
|
"payload": payload
|
||||||
})
|
})
|
||||||
log.info(f"Vector ingested for {asin} (ID: {index})")
|
log.info(f"Vector ingested for {asin} (ID: {index})")
|
||||||
else:
|
else:
|
||||||
log.error(f"Failed to ingest vector for {asin}")
|
log.error(f"Failed to ingest vector for {asin}")
|
||||||
|
|
||||||
return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst})
|
return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app_router.post("/create")
|
@app_router.post("/create")
|
||||||
async def create_collection_endpoint(
|
async def create_collection_endpoint(
|
||||||
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
||||||
body: CreateCollectionSerializer = None
|
body: CreateCollectionSerializer = None
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if body is None:
|
if body is None:
|
||||||
raise HTTPException(status_code=400, detail="Collection name is required")
|
raise HTTPException(status_code=400, detail="Collection name is required")
|
||||||
|
|
||||||
print("collection_name: ", body.collection_name)
|
print("collection_name: ", body.collection_name)
|
||||||
|
|
||||||
handler = CollectionHandler(
|
handler = CollectionHandler(
|
||||||
collection_name=body.collection_name,
|
collection_name=body.collection_name,
|
||||||
vector=body.vector,
|
vector=body.vector,
|
||||||
vector_size=body.vector_size,
|
vector_size=body.vector_size,
|
||||||
payload=body.payload,
|
payload=body.payload,
|
||||||
id=body.id
|
id=body.id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. Create collection
|
# 1. Create collection
|
||||||
result = await handler.create_collection()
|
result = await handler.create_collection()
|
||||||
|
|
||||||
# 2. Automatically call upsert_point
|
# 2. Automatically call upsert_point
|
||||||
await handler.upsert_point()
|
await handler.upsert_point()
|
||||||
|
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app_router.get("/query")
|
@app_router.get("/query")
|
||||||
async def query_collection_endpoint(
|
async def query_collection_endpoint(
|
||||||
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
||||||
body: QueryCollectionSerializer
|
body: QueryCollectionSerializer
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
handler = CollectionHandler(
|
handler = CollectionHandler(
|
||||||
collection_name=body.collection_name,
|
collection_name=body.collection_name,
|
||||||
vector=body.query_vector,
|
vector=body.query_vector,
|
||||||
vector_size=len(body.query_vector),
|
vector_size=len(body.query_vector),
|
||||||
payload={},
|
payload={},
|
||||||
id=0
|
id=0
|
||||||
)
|
)
|
||||||
result = await handler.search(body.query_vector)
|
result = await handler.search(body.query_vector)
|
||||||
return JSONResponse({"results": str(result)})
|
return JSONResponse({"results": str(result)})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app_router.put("/update")
|
@app_router.put("/update")
|
||||||
async def update_collection_endpoint(
|
async def update_collection_endpoint(
|
||||||
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
||||||
body: UpdateCollectionSerializer
|
body: UpdateCollectionSerializer
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
handler = CollectionHandler(
|
handler = CollectionHandler(
|
||||||
collection_name=body.collection_name,
|
collection_name=body.collection_name,
|
||||||
vector=body.vector,
|
vector=body.vector,
|
||||||
vector_size=len(body.vector),
|
vector_size=len(body.vector),
|
||||||
payload=body.payload,
|
payload=body.payload,
|
||||||
id=body.id
|
id=body.id
|
||||||
)
|
)
|
||||||
result = await handler.update_collection()
|
result = await handler.update_collection()
|
||||||
return JSONResponse({"status": "success", "result": result})
|
return JSONResponse({"status": "success", "result": result})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app_router.delete("/delete")
|
@app_router.delete("/delete")
|
||||||
async def delete_collection_endpoint(
|
async def delete_collection_endpoint(
|
||||||
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
||||||
body: DeleteCollectionSerializer
|
body: DeleteCollectionSerializer
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
handler = CollectionHandler(
|
handler = CollectionHandler(
|
||||||
collection_name=body.collection_name,
|
collection_name=body.collection_name,
|
||||||
vector=[],
|
vector=[],
|
||||||
vector_size=0,
|
vector_size=0,
|
||||||
payload={},
|
payload={},
|
||||||
id=0
|
id=0
|
||||||
)
|
)
|
||||||
result = await handler.delete_collection()
|
result = await handler.delete_collection()
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue