Compare commits
2 Commits
56f52a4edc
...
6ed79621b0
| Author | SHA1 | Date |
|---|---|---|
|
|
6ed79621b0 | |
|
|
0fe92f182f |
|
|
@ -5,5 +5,5 @@
|
|||
**pycache**
|
||||
*agent/**
|
||||
**downloaded_images**
|
||||
__pycache__/
|
||||
*.pyc
|
||||
**model_export**
|
||||
**cpython**
|
||||
|
|
@ -1,104 +1,108 @@
|
|||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.models import PointStruct
|
||||
from typing import Dict, Any
|
||||
|
||||
class CollectionHandler:
    """Wrapper around an AsyncQdrantClient that bundles one collection's
    configuration and one point's data (id/vector/payload), so a single
    object can create the collection and insert or query points.
    """

    def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
                 id: int = None,
                 link: str = None,
                 asin: str = None,
                 category: str = None,
                 brand: str = None,
                 client: AsyncQdrantClient = None
                 ):
        # Collection-level settings.
        self.collection_name = collection_name
        self.vector = vector
        self.id = id
        self.vector_size = vector_size
        self.payload = payload
        # Optional per-point metadata; callers also mirror these inside payload.
        self.link = link
        self.asin = asin
        self.category = category
        self.brand = brand
        # Fall back to a local Qdrant instance when no client is injected.
        self.client = client if client else AsyncQdrantClient("localhost", port=6333)

    async def create_collection(self):
        """Create the collection (Euclidean distance) plus keyword payload
        indexes. Returns a {"message": ...} dict; errors are reported in the
        message rather than raised."""
        try:
            if await self.client.collection_exists(self.collection_name):
                return {"message": "Collection already exists"}

            await self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
                optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
            )

            # Creating payload indexes as per project logic
            # (keyword indexes on the fields used for filtered lookups).

            await self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="link",
                field_schema=models.PayloadSchemaType.KEYWORD
            )
            await self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="title",
                field_schema=models.PayloadSchemaType.KEYWORD
            )

            await self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="brand",
                field_schema=models.PayloadSchemaType.KEYWORD
            )
            await self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="asin",
                field_schema=models.PayloadSchemaType.KEYWORD
            )

            return {"message": f"Collection {self.collection_name} created successfully"}
        except Exception as e:
            return {"message": str(e)}

    async def insertion(self):
        """Upsert the configured point; returns True on success, False on error."""
        try:
            await self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    PointStruct(id=self.id, vector=self.vector, payload=self.payload)
                ]
            )
            return True
        except Exception as e:
            # Note: In a real app we'd use a logger here
            print(f"Insertion failed for ID {self.id}: {e}")
            return False

    async def upsert_point(self):
        """Alias for insertion(), kept for the router's naming convention."""
        return await self.insertion()

    async def search(self, query_vector):
        """Return up to 10 nearest points, or None on error.

        NOTE(review): `client.search` is deprecated in newer qdrant-client
        releases in favor of `query_points` — confirm the pinned version.
        """
        try:
            result = await self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=10
            )
            return result
        except Exception as e:
            print("Search failed: ", e)
            return None

    async def update_collection(self):
        """Update is implemented as an upsert of the point data."""
        return await self.upsert_point()

    async def delete_collection(self):
        """Drop the whole collection; errors are returned in the message dict."""
        try:
            await self.client.delete_collection(collection_name=self.collection_name)
            return {"message": f"Collection {self.collection_name} deleted successfully"}
        except Exception as e:
            return {"message": str(e)}
|
||||
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.models import PointStruct
|
||||
from typing import Dict, Any
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
class CollectionHandler:
    """Wrapper around an AsyncQdrantClient that bundles one collection's
    configuration and one point's data (id/vector/payload), so a single
    object can create the collection and insert or query points.
    """

    def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
                 id: int = None,
                 link: str = None,
                 asin: str = None,
                 category: str = None,
                 brand: str = None,
                 client: AsyncQdrantClient = None
                 ):
        # Collection-level settings.
        self.collection_name = collection_name
        self.vector = vector
        self.id = id
        self.vector_size = vector_size
        self.payload = payload
        # Optional per-point metadata; callers also mirror these inside payload.
        self.link = link
        self.asin = asin
        self.category = category
        self.brand = brand
        # Fall back to a local Qdrant instance when no client is injected.
        self.client = client if client else AsyncQdrantClient("localhost", port=6333)

    async def create_collection(self):
        """Create the collection (Euclidean distance) plus keyword payload
        indexes. Returns a {"message": ...} dict; errors are reported in the
        message rather than raised."""
        try:
            if await self.client.collection_exists(self.collection_name):
                return {"message": "Collection already exists"}

            await self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
                optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
            )

            # Creating payload indexes as per project logic: keyword indexes
            # on every field used for filtered lookups.
            for field_name in ("link", "title", "brand", "asin"):
                await self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=models.PayloadSchemaType.KEYWORD
                )

            return {"message": f"Collection {self.collection_name} created successfully"}
        except Exception as e:
            return {"message": str(e)}

    async def insertion(self):
        """Upsert the configured point; returns True on success, False on error."""
        try:
            await self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    PointStruct(id=self.id, vector=self.vector, payload=self.payload)
                ]
            )
            return True
        except Exception as e:
            # Use the module logger, consistent with search(); the old
            # print() here bypassed the logging configuration.
            log.error(f"Insertion failed for ID {self.id}: {e}")
            return False

    async def upsert_point(self):
        """Alias for insertion(), kept for the router's naming convention."""
        return await self.insertion()

    async def search(self, query_vector, score_threshold: float = 0.3, limit: int = 10):
        """Query nearest points; returns a list of ScoredPoint or None on error.

        score_threshold: Euclidean distance cutoff — lower = more similar.
        """
        try:
            result = await self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                score_threshold=score_threshold,
                limit=limit
            )
            return result.points
        except Exception as e:
            log.error(f"Search failed: {e}")
            return None

    async def update_collection(self):
        """Update is implemented as an upsert of the point data."""
        return await self.upsert_point()

    async def delete_collection(self):
        """Drop the whole collection; errors are returned in the message dict."""
        try:
            await self.client.delete_collection(collection_name=self.collection_name)
            return {"message": f"Collection {self.collection_name} deleted successfully"}
        except Exception as e:
            return {"message": str(e)}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
import hashlib
import os
from pathlib import Path
from urllib.parse import urlparse

import requests
from PIL import Image
|
||||
|
||||
def download_image(url: str, filename: str = None) -> str:
    """
    Download an image from URL and save it in data/temp/ folder.

    Args:
        url (str): Image URL
        filename (str, optional): Custom filename. If None, extracted from URL.

    Returns:
        str: Full path to the downloaded image

    Raises:
        requests.RequestException: on network failure or non-2xx response.
    """
    try:
        # Project root = parent of the directory containing this file.
        root_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent

        # Create data/temp folder structure.
        temp_dir = root_dir / "data" / "temp"
        temp_dir.mkdir(parents=True, exist_ok=True)

        # Generate filename if not provided.
        if not filename:
            parsed_url = urlparse(url)
            filename = os.path.basename(parsed_url.path)
            if not filename or "." not in filename:
                # Fallback: deterministic name derived from the URL so the same
                # URL always maps to the same file.  The previous hash()-based
                # name changed on every interpreter run (hash randomization),
                # defeating any reuse of already-downloaded files.
                digest = hashlib.md5(url.encode("utf-8")).hexdigest()[:10]
                filename = f"image_{digest}.jpg"

        # Ensure filename has an extension.
        if '.' not in filename:
            filename += ".jpg"

        file_path = temp_dir / filename

        # Stream the download so large images are not held in memory at once.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        # Save image chunk-by-chunk.
        with open(file_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        print(f"✅ Image downloaded: {file_path}")
        return str(file_path)

    except Exception as e:
        print(f"❌ Failed to download image: {e}")
        raise
|
||||
|
||||
def read_image(image_path: str) -> Image.Image:
    """
    Load the image at *image_path* and return it as a PIL Image in RGB mode.

    Args:
        image_path (str): Path to the image file

    Returns:
        PIL.Image.Image: Loaded image (converted to RGB when necessary)

    Raises:
        FileNotFoundError: If image doesn't exist
        Exception: For other image loading errors
    """
    try:
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at path: {image_path}")

        loaded = Image.open(image_path)

        # DINOv2 and most vision models expect 3-channel RGB input.
        loaded = loaded.convert("RGB") if loaded.mode != "RGB" else loaded

        print(f"✅ Image loaded successfully: {image_path} | Size: {loaded.size}")
        return loaded

    except FileNotFoundError as e:
        print(f"❌ File not found: {e}")
        raise
    except Exception as e:
        print(f"❌ Failed to read image: {e}")
        raise
|
||||
|
|
@ -1,22 +1,24 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any
|
||||
|
||||
class CreateCollectionSerializer(BaseModel):
    """Request body for POST /create: collection config plus one initial point."""
    collection_name: str
    vector: List[float]
    vector_size: int
    payload: Dict[str, Any]
    id: int


class QueryCollectionSerializer(BaseModel):
    """Request body for GET /query: a raw query vector to search with."""
    collection_name: str
    query_vector: List[float]


class UpdateCollectionSerializer(BaseModel):
    """Request body for PUT /update: point data to upsert in-place."""
    collection_name: str
    vector: List[float]
    payload: Dict[str, Any]
    id: int


class DeleteCollectionSerializer(BaseModel):
    """Request body for DELETE /delete: names the collection to drop."""
    collection_name: str
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Any
|
||||
|
||||
class CreateCollectionSerializer(BaseModel):
    """Request body for POST /create: collection config plus one initial point."""
    collection_name: str
    vector: List[float]
    vector_size: int
    payload: Dict[str, Any]
    id: int


class QueryCollectionSerializer(BaseModel):
    """Request body for GET /query: an image URL to embed and search with."""
    collection_name: str
    url: str
    score_threshold: float = 0.3  # Euclidean distance — lower = more similar. 0.3 = very tight match
    limit: int = 10  # maximum number of hits to return


class UpdateCollectionSerializer(BaseModel):
    """Request body for PUT /update: point data to upsert in-place."""
    collection_name: str
    vector: List[float]
    payload: Dict[str, Any]
    id: int


class DeleteCollectionSerializer(BaseModel):
    """Request body for DELETE /delete: names the collection to drop."""
    collection_name: str
|
||||
|
|
|
|||
|
|
@ -1,183 +1,233 @@
|
|||
from db_setup import get_qdrant_client
|
||||
from typing import Annotated
|
||||
from fastapi import Depends, HTTPException, APIRouter
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from fastapi.responses import JSONResponse
|
||||
from .serializers import (
|
||||
CreateCollectionSerializer,
|
||||
QueryCollectionSerializer,
|
||||
UpdateCollectionSerializer,
|
||||
DeleteCollectionSerializer
|
||||
)
|
||||
from model_export.dino_image_matching import get_vectors
|
||||
from .models import CollectionHandler
|
||||
import os
|
||||
app_router = APIRouter()
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@app_router.get("/get_vectors")
async def get_vectors_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    image_path: str = os.getenv("DATASET")
):
    """Batch-ingest DINOv2 vectors for every ASIN row of listing_data.xlsx
    into the "Product" collection.

    NOTE(review): a GET route that performs bulk writes is unconventional —
    confirm the verb. Also, the os.getenv("DATASET") default is evaluated once
    at import time, not per request.
    """
    try:
        # Construct path relative to this file
        base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
        df = pd.read_excel(excel_path)
        asin_list = df['ASIN'].dropna().astype(str).tolist()
        log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")

        # 1. Initialize/Create collection "Product"
        # DINOv2 vitb14 vector size is 768
        init_handler = CollectionHandler(
            collection_name="Product",
            vector=[],
            vector_size=768,
            payload={},
            client=q
        )
        await init_handler.create_collection()

        result_lst = []
        for index, row in df.iterrows():
            asin = str(row['ASIN'])
            # Skip rows with a missing ASIN.
            if pd.isna(row['ASIN']):
                continue

            title = str(row['Title'])
            brand = str(row['Brand'])
            link = str(row['Image'])

            # Call get_vectors to embed this item's image from the dataset dir.
            vector = get_vectors(image_input=image_path, item=asin)

            if vector is None:
                log.warning(f"Skipping {asin} due to missing image/vector")
                continue

            payload = {
                "asin": asin,
                "title": title,
                "brand": brand,
                "link": link
            }

            # 2. Ingest into Qdrant using CollectionHandler
            # Use the injected client 'q' and convert index to int
            handler = CollectionHandler(
                collection_name="Product",
                vector=vector,
                vector_size=768,
                payload=payload,
                id=int(index),
                link=link,
                asin=asin,
                brand=brand,
                client=q
            )
            success = await handler.upsert_point()

            if success:
                result_lst.append({
                    "item": asin,
                    "status": "ingested",
                    "payload": payload
                })
                log.info(f"Vector ingested for {asin} (ID: {index})")
            else:
                log.error(f"Failed to ingest vector for {asin}")

        return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app_router.post("/create")
async def create_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: CreateCollectionSerializer = None
):
    """Create a collection from the request body, then upsert its initial point.

    NOTE(review): the HTTPException(400) below is raised inside the try and is
    swallowed by the broad `except Exception`, so clients actually receive 500.
    NOTE(review): the injected client `q` is unused — the handler opens its own
    connection to localhost.
    """
    try:
        if body is None:
            raise HTTPException(status_code=400, detail="Collection name is required")

        print("collection_name: ", body.collection_name)

        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=body.vector_size,
            payload=body.payload,
            id=body.id
        )

        # 1. Create collection
        result = await handler.create_collection()

        # 2. Automatically call upsert_point
        await handler.upsert_point()

        return JSONResponse(result)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app_router.get("/query")
async def query_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: QueryCollectionSerializer
):
    """Vector-search a collection using the raw vector from the request body.

    NOTE(review): GET with a request body is unconventional; also the injected
    client `q` is unused — the handler opens its own connection.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.query_vector,
            vector_size=len(body.query_vector),
            payload={},
            id=0
        )
        result = await handler.search(body.query_vector)
        # str() because the result objects are not directly JSON-serializable.
        return JSONResponse({"results": str(result)})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app_router.put("/update")
async def update_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: UpdateCollectionSerializer
):
    """Upsert a point's vector/payload (update is implemented as upsert).

    NOTE(review): the injected client `q` is unused — the handler opens its
    own connection.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=len(body.vector),
            payload=body.payload,
            id=body.id
        )
        result = await handler.update_collection()
        return JSONResponse({"status": "success", "result": result})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app_router.delete("/delete")
async def delete_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: DeleteCollectionSerializer
):
    """Drop an entire collection by name.

    NOTE(review): the injected client `q` is unused — the handler opens its
    own connection.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=[],
            vector_size=0,
            payload={},
            id=0
        )
        result = await handler.delete_collection()
        return JSONResponse(result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
from db_setup import get_qdrant_client
|
||||
from typing import Annotated
|
||||
from fastapi import Depends, HTTPException, APIRouter
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from .plugins import download_image,read_image
|
||||
from fastapi.responses import JSONResponse
|
||||
from .serializers import (
|
||||
CreateCollectionSerializer,
|
||||
QueryCollectionSerializer,
|
||||
UpdateCollectionSerializer,
|
||||
DeleteCollectionSerializer
|
||||
)
|
||||
from model_export.dino_image_matching import get_vectors,get_embedding
|
||||
from .models import CollectionHandler
|
||||
import os
|
||||
app_router = APIRouter()
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@app_router.get("/get_vectors")
async def get_vectors_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    image_path: str = os.getenv("DATASET")
):
    """Batch-ingest DINOv2 vectors for every ASIN row of listing_data.xlsx
    into the "Product" collection.

    NOTE(review): a GET route that performs bulk writes is unconventional —
    confirm the verb. Also, the os.getenv("DATASET") default is evaluated once
    at import time, not per request.
    """
    try:
        # Construct path relative to this file
        base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
        df = pd.read_excel(excel_path)
        asin_list = df['ASIN'].dropna().astype(str).tolist()
        log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")

        # 1. Initialize/Create collection "Product"
        # DINOv2 vitb14 vector size is 768
        init_handler = CollectionHandler(
            collection_name="Product",
            vector=[],
            vector_size=768,
            payload={},
            client=q
        )
        await init_handler.create_collection()

        result_lst = []
        for index, row in df.iterrows():
            asin = str(row['ASIN'])
            # Skip rows with a missing ASIN.
            if pd.isna(row['ASIN']):
                continue

            title = str(row['Title'])
            brand = str(row['Brand'])
            link = str(row['Image'])

            # Call get_vectors to embed this item's image from the dataset dir.
            vector = get_vectors(image_input=image_path, item=asin)

            if vector is None:
                log.warning(f"Skipping {asin} due to missing image/vector")
                continue

            payload = {
                "asin": asin,
                "title": title,
                "brand": brand,
                "link": link
            }

            # 2. Ingest into Qdrant using CollectionHandler
            # Use the injected client 'q' and convert index to int
            handler = CollectionHandler(
                collection_name="Product",
                vector=vector,
                vector_size=768,
                payload=payload,
                id=int(index),
                link=link,
                asin=asin,
                brand=brand,
                client=q
            )
            success = await handler.upsert_point()

            if success:
                result_lst.append({
                    "item": asin,
                    "status": "ingested",
                    "payload": payload
                })
                log.info(f"Vector ingested for {asin} (ID: {index})")
            else:
                log.error(f"Failed to ingest vector for {asin}")

        return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app_router.post("/create")
async def create_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: CreateCollectionSerializer = None
):
    """Create a collection from the request body, then upsert its initial point.

    Returns the handler's {"message": ...} dict. Responds 400 when no body is
    sent and 500 on unexpected failure.
    """
    # Validate BEFORE the try block: previously the HTTPException(400) was
    # raised inside it and converted into a 500 by the broad except below.
    if body is None:
        raise HTTPException(status_code=400, detail="Collection name is required")

    try:
        log.info(f"Creating collection: {body.collection_name}")

        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=body.vector_size,
            payload=body.payload,
            id=body.id,
            # Reuse the injected client; previously `q` was unused and the
            # handler opened a second connection per request.
            client=q
        )

        # 1. Create collection
        result = await handler.create_collection()

        # 2. Automatically call upsert_point
        await handler.upsert_point()

        return JSONResponse(result)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def _search_by_url(q, collection_name, url, score_threshold, limit):
    """Download the image at *url*, embed it, and search *collection_name*.

    Returns a list of {"id", "score", "payload"} dicts (empty when no match
    is within the score threshold).
    """
    downloaded_image_path = download_image(url)
    # get_embedding already returns a flat list of 768 floats
    query_vector = get_embedding(downloaded_image_path)

    handler = CollectionHandler(
        collection_name=collection_name,
        vector=query_vector,
        vector_size=len(query_vector),
        payload={},
        id=0,
        client=q
    )
    search_result = await handler.search(
        query_vector,
        score_threshold=score_threshold,
        limit=limit
    )
    if search_result:
        return [
            {"id": p.id, "score": p.score, "payload": p.payload}
            for p in search_result
        ]
    return []  # No match within threshold


@app_router.get("/query")
async def query_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: QueryCollectionSerializer
):
    """Search a collection by image URL: download, embed with DINOv2, then
    vector-search with the serializer's score_threshold/limit.

    The previous version duplicated the whole download/embed/search pipeline
    across the str and list branches; both now share _search_by_url.
    """
    try:
        result = []
        if isinstance(body.url, str):
            # Handle semicolon-separated URLs by taking the first one
            target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url
            log.info(f"Querying collection {body.collection_name} with URL: {target_url}")
            result = await _search_by_url(
                q, body.collection_name, target_url, body.score_threshold, body.limit
            )

        elif isinstance(body.url, list):
            # Defensive branch: the serializer declares url as str, but keep
            # list support for callers that bypass validation.
            result = []
            for url in body.url:
                matches = await _search_by_url(
                    q, body.collection_name, url, body.score_threshold, body.limit
                )
                if matches:
                    result.append(matches)

        return JSONResponse({"results": result})
    except Exception as e:
        log.error(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app_router.put("/update")
async def update_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: UpdateCollectionSerializer
):
    """Upsert a point's vector/payload (update is implemented as upsert).

    Returns {"status": "success", "result": bool}; 500 on unexpected failure.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=len(body.vector),
            payload=body.payload,
            id=body.id,
            # Reuse the injected client (consistent with /get_vectors and
            # /query); previously a second connection was opened per request.
            client=q
        )
        result = await handler.update_collection()
        return JSONResponse({"status": "success", "result": result})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app_router.delete("/delete")
async def delete_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: DeleteCollectionSerializer
):
    """Drop an entire collection by name.

    Returns the handler's {"message": ...} dict; 500 on unexpected failure.
    """
    try:
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=[],
            vector_size=0,
            payload={},
            id=0,
            # Reuse the injected client (consistent with /get_vectors and
            # /query); previously a second connection was opened per request.
            client=q
        )
        result = await handler.delete_collection()
        return JSONResponse(result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
|
@ -46,7 +46,8 @@ def get_embedding(image_path):
|
|||
# Normalize embedding (important for cosine similarity)
|
||||
embedding = F.normalize(embedding, p=2, dim=1)
|
||||
|
||||
return embedding.cpu()
|
||||
# Return flat list (squeeze batch dim)
|
||||
return embedding.squeeze(0).cpu().tolist()
|
||||
|
||||
def get_vectors(image_input, item):
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue