Compare commits
No commits in common. "6ed79621b06617e82194c2b73eaeeba2b98bc3dd" and "56f52a4edcf6913ab0558a1e8afa572cbf30a57d" have entirely different histories.
6ed79621b0
...
56f52a4edc
|
|
@ -5,5 +5,5 @@
|
||||||
**pycache**
|
**pycache**
|
||||||
*agent/**
|
*agent/**
|
||||||
**downloaded_images**
|
**downloaded_images**
|
||||||
**model_export**
|
__pycache__/
|
||||||
**cpython**
|
*.pyc
|
||||||
|
|
@ -1,9 +1,6 @@
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
import logging
|
|
||||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class CollectionHandler:
|
class CollectionHandler:
|
||||||
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
|
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
|
||||||
|
|
@ -82,17 +79,16 @@ class CollectionHandler:
|
||||||
async def upsert_point(self):
|
async def upsert_point(self):
|
||||||
return await self.insertion()
|
return await self.insertion()
|
||||||
|
|
||||||
async def search(self, query_vector, score_threshold: float = 0.3, limit: int = 10):
|
async def search(self, query_vector):
|
||||||
try:
|
try:
|
||||||
result = await self.client.query_points(
|
result = await self.client.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query=query_vector,
|
query_vector=query_vector,
|
||||||
score_threshold=score_threshold,
|
limit=10
|
||||||
limit=limit
|
|
||||||
)
|
)
|
||||||
return result.points
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Search failed: {e}")
|
print("Search failed: ", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def update_collection(self):
|
async def update_collection(self):
|
||||||
|
|
|
||||||
|
|
@ -1,90 +0,0 @@
|
||||||
import os
|
|
||||||
import requests
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
def download_image(url: str, filename: str = None) -> str:
|
|
||||||
"""
|
|
||||||
Download an image from URL and save it in data/temp/ folder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url (str): Image URL
|
|
||||||
filename (str, optional): Custom filename. If None, extracted from URL.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Full path to the downloaded image
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get project root directory (where your main script is)
|
|
||||||
root_dir = Path(os.path.dirname(os.path.abspath(__file__))).parent
|
|
||||||
|
|
||||||
# Create data/temp folder structure
|
|
||||||
temp_dir = root_dir / "data" / "temp"
|
|
||||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Generate filename if not provided
|
|
||||||
if not filename:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
filename = os.path.basename(parsed_url.path)
|
|
||||||
if not filename or "." not in filename:
|
|
||||||
# Fallback filename
|
|
||||||
ext = filename.split('.')[-1] if '.' in filename else 'jpg'
|
|
||||||
filename = f"image_{hash(url) % 100000}.{ext}"
|
|
||||||
|
|
||||||
# Ensure filename has extension
|
|
||||||
if '.' not in filename:
|
|
||||||
filename += ".jpg"
|
|
||||||
|
|
||||||
file_path = temp_dir / filename
|
|
||||||
|
|
||||||
# Download the image
|
|
||||||
response = requests.get(url, stream=True, timeout=30)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
# Save image
|
|
||||||
with open(file_path, 'wb') as f:
|
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
print(f"✅ Image downloaded: {file_path}")
|
|
||||||
return str(file_path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Failed to download image: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def read_image(image_path: str) -> Image.Image:
|
|
||||||
"""
|
|
||||||
Read an image from the given path and return a PIL Image object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_path (str): Path to the image file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PIL.Image.Image: Loaded image
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If image doesn't exist
|
|
||||||
Exception: For other image loading errors
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not os.path.exists(image_path):
|
|
||||||
raise FileNotFoundError(f"Image not found at path: {image_path}")
|
|
||||||
|
|
||||||
# Open the image
|
|
||||||
image = Image.open(image_path)
|
|
||||||
|
|
||||||
# Convert to RGB (important for DINOv2 and most models)
|
|
||||||
if image.mode != "RGB":
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
print(f"✅ Image loaded successfully: {image_path} | Size: {image.size}")
|
|
||||||
return image
|
|
||||||
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
print(f"❌ File not found: {e}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Failed to read image: {e}")
|
|
||||||
raise
|
|
||||||
|
|
@ -10,9 +10,7 @@ class CreateCollectionSerializer(BaseModel):
|
||||||
|
|
||||||
class QueryCollectionSerializer(BaseModel):
|
class QueryCollectionSerializer(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
url: str
|
query_vector: List[float]
|
||||||
score_threshold: float = 0.3 # Euclidean distance — lower = more similar. 0.3 = very tight match
|
|
||||||
limit: int = 10
|
|
||||||
|
|
||||||
class UpdateCollectionSerializer(BaseModel):
|
class UpdateCollectionSerializer(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ from db_setup import get_qdrant_client
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import Depends, HTTPException, APIRouter
|
from fastapi import Depends, HTTPException, APIRouter
|
||||||
from qdrant_client import AsyncQdrantClient
|
from qdrant_client import AsyncQdrantClient
|
||||||
from .plugins import download_image,read_image
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from .serializers import (
|
from .serializers import (
|
||||||
CreateCollectionSerializer,
|
CreateCollectionSerializer,
|
||||||
|
|
@ -10,7 +9,7 @@ from .serializers import (
|
||||||
UpdateCollectionSerializer,
|
UpdateCollectionSerializer,
|
||||||
DeleteCollectionSerializer
|
DeleteCollectionSerializer
|
||||||
)
|
)
|
||||||
from model_export.dino_image_matching import get_vectors,get_embedding
|
from model_export.dino_image_matching import get_vectors
|
||||||
from .models import CollectionHandler
|
from .models import CollectionHandler
|
||||||
import os
|
import os
|
||||||
app_router = APIRouter()
|
app_router = APIRouter()
|
||||||
|
|
@ -133,65 +132,16 @@ async def query_collection_endpoint(
|
||||||
body: QueryCollectionSerializer
|
body: QueryCollectionSerializer
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
result = []
|
|
||||||
if isinstance(body.url, str):
|
|
||||||
# Handle semicolon-separated URLs by taking the first one
|
|
||||||
target_url = body.url.split(';')[0].strip() if ';' in body.url else body.url
|
|
||||||
log.info(f"Querying collection {body.collection_name} with URL: {target_url}")
|
|
||||||
downloaded_image_path = download_image(target_url)
|
|
||||||
query_vector = get_embedding(downloaded_image_path)
|
|
||||||
# get_embedding already returns a flat list of 768 floats
|
|
||||||
|
|
||||||
handler = CollectionHandler(
|
handler = CollectionHandler(
|
||||||
collection_name=body.collection_name,
|
collection_name=body.collection_name,
|
||||||
vector=query_vector,
|
vector=body.query_vector,
|
||||||
vector_size=len(query_vector),
|
vector_size=len(body.query_vector),
|
||||||
payload={},
|
payload={},
|
||||||
id=0,
|
id=0
|
||||||
client=q
|
|
||||||
)
|
)
|
||||||
search_result = await handler.search(
|
result = await handler.search(body.query_vector)
|
||||||
query_vector,
|
return JSONResponse({"results": str(result)})
|
||||||
score_threshold=body.score_threshold,
|
|
||||||
limit=body.limit
|
|
||||||
)
|
|
||||||
if search_result:
|
|
||||||
result = [
|
|
||||||
{"id": p.id, "score": p.score, "payload": p.payload}
|
|
||||||
for p in search_result
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
result = [] # No match within threshold
|
|
||||||
|
|
||||||
elif isinstance(body.url, list):
|
|
||||||
result = []
|
|
||||||
for url in body.url:
|
|
||||||
downloaded_image_path = download_image(url)
|
|
||||||
query_vector = get_embedding(downloaded_image_path)
|
|
||||||
# get_embedding already returns a flat list of 768 floats
|
|
||||||
|
|
||||||
handler = CollectionHandler(
|
|
||||||
collection_name=body.collection_name,
|
|
||||||
vector=query_vector,
|
|
||||||
vector_size=len(query_vector),
|
|
||||||
payload={},
|
|
||||||
id=0,
|
|
||||||
client=q
|
|
||||||
)
|
|
||||||
search_result = await handler.search(
|
|
||||||
query_vector,
|
|
||||||
score_threshold=body.score_threshold,
|
|
||||||
limit=body.limit
|
|
||||||
)
|
|
||||||
if search_result:
|
|
||||||
result.append([
|
|
||||||
{"id": p.id, "score": p.score, "payload": p.payload}
|
|
||||||
for p in search_result
|
|
||||||
])
|
|
||||||
|
|
||||||
return JSONResponse({"results": result})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Query failed: {e}")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app_router.put("/update")
|
@app_router.put("/update")
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,7 @@ def get_embedding(image_path):
|
||||||
# Normalize embedding (important for cosine similarity)
|
# Normalize embedding (important for cosine similarity)
|
||||||
embedding = F.normalize(embedding, p=2, dim=1)
|
embedding = F.normalize(embedding, p=2, dim=1)
|
||||||
|
|
||||||
# Return flat list (squeeze batch dim)
|
return embedding.cpu()
|
||||||
return embedding.squeeze(0).cpu().tolist()
|
|
||||||
|
|
||||||
def get_vectors(image_input, item):
|
def get_vectors(image_input, item):
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue