# Source listing: 105 lines, 3.8 KiB, Python
import logging
from typing import Any, Dict, Optional

from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct


logger = logging.getLogger(__name__)


class CollectionHandler:
    """Async helper around one Qdrant collection and the data of one point.

    An instance carries both collection-level settings (name, vector size)
    and a single point (id, vector, payload). ``insertion`` /
    ``upsert_point`` / ``update_collection`` write that point to the
    collection.

    Error-handling convention (kept for backward compatibility): methods do
    not raise on client errors. Collection-level methods return a
    ``{"message": ...}`` dict, ``insertion`` returns ``False`` and ``search``
    returns ``None``; the failure is logged with its traceback.
    """

    # Payload fields that receive a KEYWORD index when the collection is
    # created, in creation order.
    # NOTE(review): ``category`` is stored on the instance but not indexed,
    # while ``title`` (not a constructor argument) is -- confirm intended.
    _KEYWORD_INDEX_FIELDS = ("link", "title", "brand", "asin")

    def __init__(
        self,
        collection_name: str,
        vector: Any,
        vector_size: int,
        payload: Dict,
        id: Optional[int] = None,  # shadows the builtin; name kept for callers
        link: Optional[str] = None,
        asin: Optional[str] = None,
        category: Optional[str] = None,
        brand: Optional[str] = None,
        client: Optional[AsyncQdrantClient] = None,
    ):
        """Store collection settings and the data of one point.

        Args:
            collection_name: Name of the Qdrant collection this handler targets.
            vector: Vector of the point written by ``insertion``.
            vector_size: Vector dimensionality used by ``create_collection``.
            payload: Payload of the point written by ``insertion``.
            id: Point id passed to ``PointStruct`` by ``insertion``.
            link, asin, category, brand: Stored as attributes; no method in
                this class reads them (presumably mirrored in ``payload`` --
                verify against callers).
            client: Client to reuse. When omitted, a new
                ``AsyncQdrantClient("localhost", port=6333)`` is created.
        """
        self.collection_name = collection_name
        self.vector = vector
        self.id = id
        self.vector_size = vector_size
        self.payload = payload
        self.link = link
        self.asin = asin
        self.category = category
        self.brand = brand
        # Explicit None check: only fall back to a local client when none is given.
        self.client = (
            client
            if client is not None
            else AsyncQdrantClient("localhost", port=6333)
        )

    async def create_collection(self):
        """Create the collection and its keyword payload indexes if missing.

        The collection uses ``self.vector_size``-dimensional vectors with
        ``Distance.EUCLID`` and an optimizer ``indexing_threshold`` of 20000.

        Returns:
            A ``{"message": ...}`` dict describing the outcome. On error the
            message is the exception text. If an index creation fails, the
            collection itself may already have been created.
        """
        try:
            if await self.client.collection_exists(self.collection_name):
                return {"message": "Collection already exists"}

            await self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.vector_size, distance=models.Distance.EUCLID
                ),
                optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000),
            )

            # Creating payload indexes as per project logic
            for field_name in self._KEYWORD_INDEX_FIELDS:
                await self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=models.PayloadSchemaType.KEYWORD,
                )

            return {"message": f"Collection {self.collection_name} created successfully"}
        except Exception as e:
            logger.exception("Creating collection %s failed", self.collection_name)
            return {"message": str(e)}

    async def insertion(self):
        """Upsert this handler's point (id, vector, payload) into the collection.

        Returns:
            ``True`` on success, ``False`` if the upsert raised (error is logged).
        """
        try:
            await self.client.upsert(
                collection_name=self.collection_name,
                points=[PointStruct(id=self.id, vector=self.vector, payload=self.payload)],
            )
        except Exception:
            logger.exception("Insertion failed for ID %s", self.id)
            return False
        return True

    async def upsert_point(self):
        """Alias for ``insertion``; returns its ``True``/``False`` result."""
        return await self.insertion()

    async def search(self, query_vector, limit: int = 10):
        """Return up to ``limit`` points closest to ``query_vector``.

        Args:
            query_vector: Vector to search with.
            limit: Maximum number of hits to return (default 10, as before).

        Returns:
            Whatever ``client.search`` returns, or ``None`` if it raised (the
            error is logged).
        """
        # NOTE(review): newer qdrant-client releases deprecate ``search`` in
        # favour of ``query_points`` -- check against the pinned client version.
        try:
            return await self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
        except Exception:
            logger.exception("Search failed")
            return None

    async def update_collection(self):
        """Update is implemented as an upsert of the point data."""
        return await self.upsert_point()

    async def delete_collection(self):
        """Delete the collection.

        Returns:
            A ``{"message": ...}`` dict; on error the message is the exception text.
        """
        try:
            await self.client.delete_collection(collection_name=self.collection_name)
            return {"message": f"Collection {self.collection_name} deleted successfully"}
        except Exception as e:
            logger.exception("Deleting collection %s failed", self.collection_name)
            return {"message": str(e)}