Compare commits
No commits in common. "0fe92f182fc5007e1eecae54950bb126d2610028" and "ddc80f04acfb98a43ebb2d8775cc1d7a70cd2cb5" have entirely different histories.
0fe92f182f
...
ddc80f04ac
|
|
@ -1,9 +0,0 @@
|
||||||
.env
|
|
||||||
**venv**
|
|
||||||
**data**
|
|
||||||
*pyc**
|
|
||||||
**pycache**
|
|
||||||
*agent/**
|
|
||||||
**downloaded_images**
|
|
||||||
**model_export**
|
|
||||||
**cpython**
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
import os
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
|
||||||
from typing import Annotated
|
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
||||||
from sqlalchemy.orm import declarative_base
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
# pyrefly: ignore [missing-import]
|
|
||||||
from fastapi import FastAPI, Depends
|
|
||||||
from sqlalchemy.engine.url import URL
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
# Async MySQL connection URL assembled from the MYSQL_* environment variables.
# os.getenv returns None for any variable that is unset, so a missing setting
# is handed to URL.create as None instead of failing on this line.
DATABASE_URL = URL.create(
    drivername="mysql+asyncmy",
    username=os.getenv("MYSQL_USER"),
    password=os.getenv("MYSQL_PASSWORD"),
    host=os.getenv("MYSQL_HOST"),
    # NOTE(review): the port reaches URL.create as a string (or None) — confirm that is accepted.
    port=os.getenv("MYSQL_PORT"),
    database=os.getenv("MYSQL_DATABASE"),
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_qdrant_client() -> AsyncGenerator[AsyncQdrantClient, None]:
    """Yield an ``AsyncQdrantClient`` and close it once the consumer is done.

    The server URL is taken from the ``QDRANT_URL`` environment variable and
    falls back to ``http://localhost:6333`` (the value that used to be
    hard-coded) when the variable is unset or empty.

    Yields:
        AsyncQdrantClient: client created with a 60 second timeout.
    """
    qdrant_url = os.getenv("QDRANT_URL") or "http://localhost:6333"
    client = AsyncQdrantClient(url=qdrant_url, timeout=60)
    try:
        yield client
    finally:
        # Properly close the async client even if the consumer raised.
        await client.close()
|
|
||||||
|
|
||||||
# Lazily created session factory shared by every call to get_session().
_SESSION_FACTORY = None


def _get_session_factory():
    """Return the shared session factory, building the engine on first use."""
    global _SESSION_FACTORY
    if _SESSION_FACTORY is None:
        engine = create_async_engine(DATABASE_URL, echo=True)
        _SESSION_FACTORY = async_sessionmaker(
            bind=engine, class_=AsyncSession, expire_on_commit=False
        )
    return _SESSION_FACTORY


async def get_session():
    """Yield an ``AsyncSession`` and close it once the consumer is done.

    The engine is created once and reused. Previously every call built a
    brand-new engine (and its connection pool) that was never disposed.

    Yields:
        AsyncSession: session bound to the shared engine, created with
        ``expire_on_commit=False``.
    """
    session = _get_session_factory()()
    try:
        yield session
    finally:
        await session.close()
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Add the project root to sys.path to allow imports from model_export
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from fastapi import FastAPI, status
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from db_setup import get_qdrant_client,get_session
|
|
||||||
from mysql_process.views import app_router as mysql_router
|
|
||||||
from vector_db_router.views import app_router as vector_db_router
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
# FastAPI application object; the interactive docs are configured at /docs
# and the ReDoc page at /redocs.
api = FastAPI(
    docs_url="/docs",
    redoc_url="/redocs",
)

# Mount the two feature routers under their URL prefixes.
api.include_router(mysql_router,prefix="/mysql",tags=["mysql_process"])
api.include_router(vector_db_router,prefix="/collection",tags=["vector_db"])
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
||||||
from fastapi import FastAPI
|
|
||||||
from sqlmodel import SQLModel, Field
|
|
||||||
|
|
||||||
class Memory(SQLModel,table=True):
    """Table model holding one product record.

    NOTE(review): every field defaults to None although it is annotated with
    a non-optional type — confirm whether the columns are meant to be nullable.
    """
    # Primary key of the row.
    id: int = Field(default=None,primary_key=True)
    # Product link; declared with index=True.
    product_link: str = Field(default=None,index=True)
    price: float = Field(default=None)
    product_image: str = Field(default=None)
    product_name: str = Field(default=None)
    product_description: str = Field(default=None)
    product_rating: float = Field(default=None)
    product_review: str = Field(default=None)
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
from fastapi import FastAPI
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from .models import Memory
|
|
||||||
|
|
||||||
class MemorySerializer(BaseModel):
    """Response schema mirroring the fields of the ``Memory`` table model."""

    # Pydantic v2 renamed the ``orm_mode`` option to ``from_attributes``; the
    # old class-based ``Config.orm_mode`` is ignored there, so the serializer
    # could not be populated from ORM objects.
    model_config = {"from_attributes": True}

    id: int
    product_link: str
    price: float
    product_image: str
    product_name: str
    product_description: str
    product_rating: float
    product_review: str
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
from typing import Annotated
|
|
||||||
from fastapi import Depends, Header, HTTPException,APIRouter
|
|
||||||
from typing import List,Optional
|
|
||||||
from db_setup import get_session,get_qdrant_client
|
|
||||||
# from .models import Product
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from qdrant_client import AsyncQdrantClient
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
app_router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
@app_router.get("/products")
async def get_all_products(
    session: Annotated[AsyncSession, Depends(get_session)],
    vector_db: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
):
    """Placeholder products endpoint.

    Both dependencies are resolved but not used yet; the handler answers
    with a fixed JSON message. Any unexpected error is reported as HTTP 500.
    """
    try:
        greeting = {"message": "Hello World"}
        response = JSONResponse(content=greeting)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return response
|
|
||||||
|
|
@ -1,108 +0,0 @@
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
|
||||||
from qdrant_client.models import PointStruct
|
|
||||||
from typing import Dict, Any
|
|
||||||
import logging
|
|
||||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class CollectionHandler:
    """Async helper bundling one Qdrant collection name with one point's data.

    The instance stores the collection name, a single vector with its id and
    payload, and the client used for all calls. Every method catches its own
    exceptions and reports failure through its return value instead of
    raising.
    """

    # Payload fields that receive a KEYWORD index when the collection is created.
    _KEYWORD_INDEX_FIELDS = ("link", "title", "brand", "asin")

    def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
                 id: int = None,
                 link: str = None,
                 asin: str = None,
                 category: str = None,
                 brand: str = None,
                 client: "AsyncQdrantClient" = None,
                 ):
        """Store the point data and pick the client.

        Args:
            collection_name: Name of the Qdrant collection to operate on.
            vector: Vector of the point handled by this instance.
            vector_size: Dimension used when the collection is created.
            payload: Payload stored with the point.
            id: Point id used on upsert.
            link, asin, category, brand: Stored on the instance as given.
            client: Client to use; when falsy a new client for
                ``localhost:6333`` is created.
        """
        self.collection_name = collection_name
        self.vector = vector
        self.id = id
        self.vector_size = vector_size
        self.payload = payload
        self.link = link
        self.asin = asin
        self.category = category
        self.brand = brand
        self.client = client if client else AsyncQdrantClient("localhost", port=6333)

    async def create_collection(self):
        """Create the collection and its payload indexes if it does not exist.

        Returns:
            dict: ``{"message": ...}`` describing the outcome; on failure the
            message is the text of the exception.
        """
        try:
            if await self.client.collection_exists(self.collection_name):
                return {"message": "Collection already exists"}

            await self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
                optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000),
            )

            # Creating payload indexes as per project logic.
            for field_name in self._KEYWORD_INDEX_FIELDS:
                await self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=models.PayloadSchemaType.KEYWORD,
                )

            return {"message": f"Collection {self.collection_name} created successfully"}
        except Exception as e:
            # The error is still returned to the caller; log it so it is not lost.
            log.error("Creating collection %s failed: %s", self.collection_name, e)
            return {"message": str(e)}

    async def insertion(self):
        """Upsert this instance's point into the collection.

        Returns:
            bool: True on success, False if the upsert raised.
        """
        try:
            await self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    PointStruct(id=self.id, vector=self.vector, payload=self.payload)
                ],
            )
            return True
        except Exception as e:
            # Use the module logger like search() does, instead of print().
            log.error("Insertion failed for ID %s: %s", self.id, e)
            return False

    async def upsert_point(self):
        """Alias for :meth:`insertion`."""
        return await self.insertion()

    async def search(self, query_vector, score_threshold: float = 0.3, limit: int = 10):
        """Query the collection with *query_vector*.

        Args:
            query_vector: Vector to search with.
            score_threshold: Forwarded to ``query_points``.
            limit: Maximum number of points requested.

        Returns:
            The ``points`` of the query result, or None if the query raised.
        """
        try:
            result = await self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                score_threshold=score_threshold,
                limit=limit,
            )
            return result.points
        except Exception as e:
            log.error(f"Search failed: {e}")
            return None

    async def update_collection(self):
        """Update is implemented as an upsert of the point data."""
        return await self.upsert_point()

    async def delete_collection(self):
        """Delete the collection.

        Returns:
            dict: ``{"message": ...}`` describing the outcome; on failure the
            message is the text of the exception.
        """
        try:
            await self.client.delete_collection(collection_name=self.collection_name)
            return {"message": f"Collection {self.collection_name} deleted successfully"}
        except Exception as e:
            log.error("Deleting collection %s failed: %s", self.collection_name, e)
            return {"message": str(e)}
|
|
||||||
|
|
||||||
|
|
@ -1,90 +0,0 @@
|
||||||
import os
|
|
||||||
import requests
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
def download_image(url: str, filename: str = None) -> str:
    """
    Download an image from URL and save it in data/temp/ folder.

    Args:
        url (str): Image URL
        filename (str, optional): Custom filename. If None, extracted from URL.

    Returns:
        str: Full path to the downloaded image

    Raises:
        Exception: Any error from creating the folder, the HTTP request or
            writing the file is printed and re-raised.
    """
    # Local import: only needed for the fallback filename.
    import hashlib

    try:
        # Parent of the directory that contains this file.
        root_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent

        # Create data/temp folder structure
        temp_dir = root_dir / "data" / "temp"
        temp_dir.mkdir(parents=True, exist_ok=True)

        # Generate filename if not provided
        if not filename:
            filename = os.path.basename(urlparse(url).path)
            if not filename or "." not in filename:
                # Fallback filename. A digest of the URL is stable across
                # processes, unlike the built-in hash() of a str.
                digest = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
                filename = f"image_{digest}.jpg"

        # Ensure filename has extension
        if '.' not in filename:
            filename += ".jpg"

        file_path = temp_dir / filename

        # Download the image; the context manager releases the streamed
        # connection even if writing the file fails.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()

            # Save image
            with open(file_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        print(f"✅ Image downloaded: {file_path}")
        return str(file_path)

    except Exception as e:
        print(f"❌ Failed to download image: {e}")
        raise
|
|
||||||
|
|
||||||
def read_image(image_path: str) -> "Image.Image":
    """
    Read an image from the given path and return a PIL Image object.

    The file is closed before returning; the returned image is an RGB copy
    of the file's contents.

    Args:
        image_path (str): Path to the image file

    Returns:
        PIL.Image.Image: Loaded image in RGB mode

    Raises:
        FileNotFoundError: If image doesn't exist
        Exception: For other image loading errors
    """
    try:
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at path: {image_path}")

        # Open the image inside a context manager so the file handle is
        # released; convert() returns a separate RGB image
        # (important for DINOv2 and most models).
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")

        print(f"✅ Image loaded successfully: {image_path} | Size: {image.size}")
        return image

    except FileNotFoundError as e:
        print(f"❌ File not found: {e}")
        raise
    except Exception as e:
        print(f"❌ Failed to read image: {e}")
        raise
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Dict, List, Any
|
|
||||||
|
|
||||||
class CreateCollectionSerializer(BaseModel):
    """Request body carrying a collection name, its vector size and one point."""
    collection_name: str
    # Vector of the point.
    vector: List[float]
    # Vector dimension declared for the collection.
    vector_size: int
    # Free-form payload of the point.
    payload: Dict[str, Any]
    # Id of the point.
    id: int
|
|
||||||
|
|
||||||
class QueryCollectionSerializer(BaseModel):
    """Request body for an image-similarity query.

    ``url`` accepts a single image URL or a list of image URLs. A plain
    string keeps working exactly as before.
    """
    collection_name: str
    url: str | List[str]
    score_threshold: float = 0.3  # Euclidean distance — lower = more similar. 0.3 = very tight match
    # Maximum number of matches requested per image.
    limit: int = 10
|
|
||||||
|
|
||||||
class UpdateCollectionSerializer(BaseModel):
    """Request body carrying one point (vector, payload, id) for a collection."""
    collection_name: str
    vector: List[float]
    payload: Dict[str, Any]
    id: int
|
|
||||||
|
|
||||||
class DeleteCollectionSerializer(BaseModel):
    """Request body naming a collection."""
    collection_name: str
|
|
||||||
|
|
@ -1,233 +0,0 @@
|
||||||
from db_setup import get_qdrant_client
|
|
||||||
from typing import Annotated
|
|
||||||
from fastapi import Depends, HTTPException, APIRouter
|
|
||||||
from qdrant_client import AsyncQdrantClient
|
|
||||||
from .plugins import download_image,read_image
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from .serializers import (
|
|
||||||
CreateCollectionSerializer,
|
|
||||||
QueryCollectionSerializer,
|
|
||||||
UpdateCollectionSerializer,
|
|
||||||
DeleteCollectionSerializer
|
|
||||||
)
|
|
||||||
from model_export.dino_image_matching import get_vectors,get_embedding
|
|
||||||
from .models import CollectionHandler
|
|
||||||
import os
|
|
||||||
app_router = APIRouter()
|
|
||||||
import logging
|
|
||||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
@app_router.get("/get_vectors")
async def get_vectors_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    image_path:str=os.getenv("DATASET")
):
    """Embed every listed product image and ingest it into the "Product" collection.

    Reads ``model_export/listing_data.xlsx`` (located three directory levels
    above this file), makes sure the "Product" collection exists, and then
    upserts one point per row that has an ASIN and for which ``get_vectors``
    returns a vector. The row index is used as the point id.

    Returns:
        JSONResponse listing the ingested items. Any exception is reported
        as HTTP 500 with the exception text.
    """
    try:
        # Construct path relative to this file
        this_file = os.path.abspath(__file__)
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(this_file)))
        excel_path = os.path.join(project_root, "model_export", "listing_data.xlsx")
        df = pd.read_excel(excel_path)

        asin_count = len(df['ASIN'].dropna().astype(str).tolist())
        log.info(f"Generating vectors and ingesting {asin_count} ASINs into Qdrant from {excel_path}")

        # 1. Initialize/Create collection "Product"
        # DINOv2 vitb14 vector size is 768
        await CollectionHandler(
            collection_name="Product",
            vector=[],
            vector_size=768,
            payload={},
            client=q,
        ).create_collection()

        ingested = []
        for index, row in df.iterrows():
            # Rows without an ASIN are skipped.
            if pd.isna(row['ASIN']):
                continue
            asin = str(row['ASIN'])

            payload = {
                "asin": asin,
                "title": str(row['Title']),
                "brand": str(row['Brand']),
                "link": str(row['Image']),
            }

            vector = get_vectors(image_input=image_path, item=asin)
            if vector is None:
                log.warning(f"Skipping {asin} due to missing image/vector")
                continue

            # 2. Ingest into Qdrant using CollectionHandler
            # Use the injected client 'q' and convert index to int
            point_handler = CollectionHandler(
                collection_name="Product",
                vector=vector,
                vector_size=768,
                payload=payload,
                id=int(index),
                link=payload["link"],
                asin=asin,
                brand=payload["brand"],
                client=q,
            )
            if not await point_handler.upsert_point():
                log.error(f"Failed to ingest vector for {asin}")
                continue

            ingested.append({
                "item": asin,
                "status": "ingested",
                "payload": payload,
            })
            log.info(f"Vector ingested for {asin} (ID: {index})")

        return JSONResponse({"status": "success", "message": f"Ingested {len(ingested)} items into Product collection", "result": ingested})
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
|
|
||||||
|
|
||||||
@app_router.post("/create")
async def create_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: CreateCollectionSerializer = None
):
    """Create a collection and upsert the point supplied in the body.

    Returns:
        JSONResponse with the message produced by ``create_collection``.

    Raises:
        HTTPException: 400 when no body is sent, 500 for any other error.
    """
    # Validate outside the try block: inside it the 400 was caught by the
    # generic handler below and turned into a 500.
    if body is None:
        raise HTTPException(status_code=400, detail="Collection name is required")

    try:
        log.info("collection_name: %s", body.collection_name)

        # Use the injected client so no second, never-closed client is created.
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=body.vector_size,
            payload=body.payload,
            id=body.id,
            client=q,
        )

        # 1. Create collection
        result = await handler.create_collection()

        # 2. Automatically call upsert_point
        await handler.upsert_point()

        return JSONResponse(result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
async def _search_by_image_url(q, collection_name, url, score_threshold, limit):
    """Download the image at *url*, embed it and return the matches as dicts."""
    downloaded_image_path = download_image(url)
    # get_embedding already returns a flat list of 768 floats
    query_vector = get_embedding(downloaded_image_path)

    handler = CollectionHandler(
        collection_name=collection_name,
        vector=query_vector,
        vector_size=len(query_vector),
        payload={},
        id=0,
        client=q,
    )
    search_result = await handler.search(
        query_vector,
        score_threshold=score_threshold,
        limit=limit,
    )
    # search() returns None on failure; treat that like "no match".
    return [
        {"id": p.id, "score": p.score, "payload": p.payload}
        for p in (search_result or [])
    ]


@app_router.get("/query")
async def query_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: QueryCollectionSerializer
):
    """Search a collection for images similar to the one(s) at ``body.url``.

    A string URL yields a flat list of matches (for a semicolon-separated
    string only the first URL is used). A list of URLs yields one list of
    matches per URL that produced at least one match.

    Raises:
        HTTPException: 500 with the exception text on any error.
    """
    try:
        result = []
        if isinstance(body.url, str):
            # Handle semicolon-separated URLs by taking the first one
            target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url
            log.info(f"Querying collection {body.collection_name} with URL: {target_url}")
            result = await _search_by_image_url(
                q, body.collection_name, target_url, body.score_threshold, body.limit
            )
        elif isinstance(body.url, list):
            for url in body.url:
                matches = await _search_by_image_url(
                    q, body.collection_name, url, body.score_threshold, body.limit
                )
                # URLs without a match within the threshold are left out.
                if matches:
                    result.append(matches)

        return JSONResponse({"results": result})
    except Exception as e:
        log.error(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
@app_router.put("/update")
async def update_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: UpdateCollectionSerializer
):
    """Upsert the point supplied in the body into its collection.

    Returns:
        JSONResponse whose ``result`` field is the boolean returned by
        ``update_collection``.

    Raises:
        HTTPException: 500 with the exception text on any error.
    """
    try:
        # Use the injected client so no second, never-closed client is created.
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=body.vector,
            vector_size=len(body.vector),
            payload=body.payload,
            id=body.id,
            client=q,
        )
        result = await handler.update_collection()
        return JSONResponse({"status": "success", "result": result})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
@app_router.delete("/delete")
async def delete_collection_endpoint(
    q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
    body: DeleteCollectionSerializer
):
    """Delete the collection named in the body.

    Returns:
        JSONResponse with the message produced by ``delete_collection``.

    Raises:
        HTTPException: 500 with the exception text on any error.
    """
    try:
        # Only the collection name matters here; the point fields are
        # placeholders. Use the injected client so no second, never-closed
        # client is created.
        handler = CollectionHandler(
            collection_name=body.collection_name,
            vector=[],
            vector_size=0,
            payload={},
            id=0,
            client=q,
        )
        result = await handler.delete_collection()
        return JSONResponse(result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# Optional dependency warnings from DINOv2 internals are non-critical.
warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning)

# Load model: the 'dinov2_vitb14' entry point of 'facebookresearch/dinov2'
# is fetched through torch.hub when this module is imported.
model = torch.hub.load(
    'facebookresearch/dinov2',
    'dinov2_vitb14'
)

model.eval()

# Device: use CUDA when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Image preprocessing
# NOTE(review): no mean/std normalisation step is applied here — confirm that is intended.
transform = transforms.Compose([
    transforms.Resize((518, 518)),  # DINOv2 recommended size
    transforms.ToTensor(),
])
|
||||||
|
|
||||||
|
def get_embedding(image_path):
    """Return the L2-normalised model embedding of the image at *image_path*.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        torch.Tensor: CPU tensor holding the embedding of the one-image
        batch, normalised along dimension 1.
    """
    # Load image; the context manager closes the file once the RGB copy exists.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    # Transform and add the batch dimension
    tensor = transform(image).unsqueeze(0).to(device)

    # Generate embedding
    with torch.no_grad():
        embedding = model(tensor)

    # Normalize embedding (important for cosine similarity)
    embedding = F.normalize(embedding, p=2, dim=1)

    return embedding.cpu()
|
||||||
|
|
||||||
|
# Example
|
||||||
|
# Example: embed two sample images.
# NOTE(review): the paths are relative and use backslashes — confirm they resolve where this script is run.
emb1 = get_embedding(r"data_images\B0B39FFJHF\03.jpg")
emb2 = get_embedding(r"data_images\B09RWY127Q\03.jpg")

# Pairwise distance between the two embeddings. pdist returns a distance
# (printed as such below), not a cosine similarity.
similarity = torch.nn.functional.pdist(
    torch.cat([emb1, emb2])
)

print("Distance:", similarity.item())
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
import warnings
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
import torch,glob
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
import os
|
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
|
||||||
from db_setup import get_qdrant_client
|
|
||||||
from vector_db_router.models import CollectionHandler
|
|
||||||
from vector_db_router.serializers import CreateCollectionSerializer
|
|
||||||
import torch.nn.functional as F
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
# Optional dependency warnings from DINOv2 internals are non-critical.
warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning)

# Load model: the 'dinov2_vitb14' entry point of 'facebookresearch/dinov2'
# is fetched through torch.hub when this module is imported.
model = torch.hub.load(
    'facebookresearch/dinov2',
    'dinov2_vitb14'
)

model.eval()

# Device: use CUDA when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Image preprocessing
# NOTE(review): no mean/std normalisation step is applied here — confirm that is intended.
transform = transforms.Compose([
    transforms.Resize((518, 518)),  # DINOv2 recommended size
    transforms.ToTensor(),
])
|
|
||||||
|
|
||||||
def get_embedding(image_path):
    """Return the L2-normalised model embedding of the image at *image_path*.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        list: The embedding as a flat Python list (batch dimension removed).
    """
    # Load image; the context manager closes the file once the RGB copy exists.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    # Transform and add the batch dimension
    tensor = transform(image).unsqueeze(0).to(device)

    # Generate embedding
    with torch.no_grad():
        embedding = model(tensor)

    # Normalize embedding (important for cosine similarity)
    embedding = F.normalize(embedding, p=2, dim=1)

    # Return flat list (squeeze batch dim)
    return embedding.squeeze(0).cpu().tolist()
|
|
||||||
|
|
||||||
def get_vectors(image_input, item):
    """Return the embedding vector for an item's image, or None.

    Args:
        image_input: Path of an image file, or of a directory that holds one
            sub-folder per item. Anything else (including None) falls back
            to the ``downloaded_images`` folder next to this module.
        item: Item identifier (ASIN); used as the sub-folder name when an
            image has to be looked up.

    Returns:
        list | None: Flat list of floats, or None when no ``*.jpg`` is found
        or embedding the image fails (the error is printed).
    """
    try:
        base_dir = os.path.join(os.path.dirname(__file__), "downloaded_images")
        path = image_input

        # If image_input is not a valid file, try to find one using the item (ASIN)
        if not (path and os.path.isfile(path)):
            # If path is a directory, use it as base_dir
            search_base = path if (path and os.path.isdir(path)) else base_dir
            # Sorted so the same image is picked on every run; glob order is arbitrary.
            jpg_files = sorted(glob.glob(os.path.join(search_base, item, "*.jpg")))
            if not jpg_files:
                return None
            path = jpg_files[0]

        # Generate the vector for the identified image file
        emb = get_embedding(path)

        # get_embedding returns a flat list here. Calling .squeeze() on it
        # raised AttributeError, which the handler below swallowed, so every
        # item came back as None. Tensors are still flattened for
        # compatibility with a tensor-returning get_embedding.
        if hasattr(emb, "squeeze"):
            return emb.squeeze().tolist()
        return emb
    except Exception as e:
        print(f"Error generating vector for {item}: {e}")
        return None
|
|
||||||
76
req.txt
76
req.txt
|
|
@ -1,76 +0,0 @@
|
||||||
annotated-doc==0.0.4
|
|
||||||
annotated-types==0.7.0
|
|
||||||
anyio==4.13.0
|
|
||||||
asyncmy==0.2.11
|
|
||||||
certifi==2026.4.22
|
|
||||||
click==8.3.3
|
|
||||||
contourpy==1.3.3
|
|
||||||
cuda-bindings==13.2.0
|
|
||||||
cuda-pathfinder==1.5.4
|
|
||||||
cuda-toolkit==13.0.2
|
|
||||||
cycler==0.12.1
|
|
||||||
et_xmlfile==2.0.0
|
|
||||||
fastapi==0.136.1
|
|
||||||
filelock==3.29.0
|
|
||||||
fonttools==4.62.1
|
|
||||||
fsspec==2026.4.0
|
|
||||||
greenlet==3.5.0
|
|
||||||
grpcio==1.80.0
|
|
||||||
h11==0.16.0
|
|
||||||
h2==4.3.0
|
|
||||||
hpack==4.1.0
|
|
||||||
httpcore==1.0.9
|
|
||||||
httpx==0.28.1
|
|
||||||
hyperframe==6.1.0
|
|
||||||
idna==3.13
|
|
||||||
Jinja2==3.1.6
|
|
||||||
joblib==1.5.3
|
|
||||||
kiwisolver==1.5.0
|
|
||||||
MarkupSafe==3.0.3
|
|
||||||
matplotlib==3.10.9
|
|
||||||
mpmath==1.3.0
|
|
||||||
networkx==3.6.1
|
|
||||||
numpy==2.4.4
|
|
||||||
nvidia-cublas==13.1.0.3
|
|
||||||
nvidia-cuda-cupti==13.0.85
|
|
||||||
nvidia-cuda-nvrtc==13.0.88
|
|
||||||
nvidia-cuda-runtime==13.0.96
|
|
||||||
nvidia-cudnn-cu13==9.19.0.56
|
|
||||||
nvidia-cufft==12.0.0.61
|
|
||||||
nvidia-cufile==1.15.1.6
|
|
||||||
nvidia-curand==10.4.0.35
|
|
||||||
nvidia-cusolver==12.0.4.66
|
|
||||||
nvidia-cusparse==12.6.3.3
|
|
||||||
nvidia-cusparselt-cu13==0.8.0
|
|
||||||
nvidia-nccl-cu13==2.28.9
|
|
||||||
nvidia-nvjitlink==13.0.88
|
|
||||||
nvidia-nvshmem-cu13==3.4.5
|
|
||||||
nvidia-nvtx==13.0.85
|
|
||||||
openpyxl==3.1.5
|
|
||||||
packaging==26.2
|
|
||||||
pandas==3.0.2
|
|
||||||
pillow==12.2.0
|
|
||||||
portalocker==3.2.0
|
|
||||||
protobuf==7.34.1
|
|
||||||
pydantic==2.13.4
|
|
||||||
pydantic_core==2.46.4
|
|
||||||
pyparsing==3.3.2
|
|
||||||
python-dateutil==2.9.0.post0
|
|
||||||
python-dotenv==1.2.2
|
|
||||||
qdrant-client==1.17.1
|
|
||||||
scikit-learn==1.8.0
|
|
||||||
scipy==1.17.1
|
|
||||||
setuptools==81.0.0
|
|
||||||
six==1.17.0
|
|
||||||
SQLAlchemy==2.0.49
|
|
||||||
starlette==1.0.0
|
|
||||||
sympy==1.14.0
|
|
||||||
threadpoolctl==3.6.0
|
|
||||||
torch==2.11.0
|
|
||||||
torchvision==0.26.0
|
|
||||||
tqdm==4.67.3
|
|
||||||
triton==3.6.0
|
|
||||||
typing-inspection==0.4.2
|
|
||||||
typing_extensions==4.15.0
|
|
||||||
urllib3==2.6.3
|
|
||||||
uvicorn==0.46.0
|
|
||||||
Loading…
Reference in New Issue