injection completed
parent
dcaa4cc8c9
commit
26b0b7f1a9
|
|
@ -3,4 +3,5 @@
|
||||||
**data**
|
**data**
|
||||||
*pyc**
|
*pyc**
|
||||||
**pycache**
|
**pycache**
|
||||||
*agent/**
|
*agent/**
|
||||||
|
**downloaded_images**
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
MYSQL_HOST=localhost
|
|
||||||
|
|
||||||
MYSQL_PORT=3306
|
|
||||||
|
|
||||||
MYSQL_USER=root
|
|
||||||
|
|
||||||
MYSQL_PASSWORD='AmB@ig123'
|
|
||||||
|
|
||||||
MYSQL_DATABASE=listing_radar
|
|
||||||
Binary file not shown.
|
|
@ -1,3 +1,9 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add the project root to sys.path to allow imports from model_export
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from fastapi import FastAPI, status
|
from fastapi import FastAPI, status
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,24 @@ from qdrant_client.models import PointStruct
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
class CollectionHandler:
|
class CollectionHandler:
|
||||||
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, id: int):
|
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
|
||||||
|
id: int=None,
|
||||||
|
link: str=None,
|
||||||
|
asin: str=None,
|
||||||
|
category: str=None,
|
||||||
|
brand: str=None,
|
||||||
|
client: AsyncQdrantClient=None
|
||||||
|
):
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.vector = vector
|
self.vector = vector
|
||||||
|
self.id = id
|
||||||
self.vector_size = vector_size
|
self.vector_size = vector_size
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.id = id
|
self.link = link
|
||||||
self.client = AsyncQdrantClient("localhost", port=6333)
|
self.asin = asin
|
||||||
|
self.category = category
|
||||||
|
self.brand = brand
|
||||||
|
self.client = client if client else AsyncQdrantClient("localhost", port=6333)
|
||||||
|
|
||||||
async def create_collection(self):
|
async def create_collection(self):
|
||||||
try:
|
try:
|
||||||
|
|
@ -18,21 +29,34 @@ class CollectionHandler:
|
||||||
|
|
||||||
await self.client.create_collection(
|
await self.client.create_collection(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.COSINE),
|
vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
|
||||||
optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
|
optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Creating payload indexes as per project logic
|
# Creating payload indexes as per project logic
|
||||||
|
|
||||||
await self.client.create_payload_index(
|
await self.client.create_payload_index(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
field_name="Product_ID",
|
field_name="link",
|
||||||
field_schema=models.PayloadSchemaType.KEYWORD
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
)
|
)
|
||||||
await self.client.create_payload_index(
|
await self.client.create_payload_index(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
field_name="Product_Link",
|
field_name="title",
|
||||||
field_schema=models.PayloadSchemaType.KEYWORD
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.client.create_payload_index(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
field_name="brand",
|
||||||
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
|
)
|
||||||
|
await self.client.create_payload_index(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
field_name="asin",
|
||||||
|
field_schema=models.PayloadSchemaType.KEYWORD
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
return {"message": f"Collection {self.collection_name} created successfully"}
|
return {"message": f"Collection {self.collection_name} created successfully"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -46,10 +70,10 @@ class CollectionHandler:
|
||||||
PointStruct(id=self.id, vector=self.vector, payload=self.payload)
|
PointStruct(id=self.id, vector=self.vector, payload=self.payload)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
print("Data inserted successfully")
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Insertion failed: ", e)
|
# Note: In a real app we'd use a logger here
|
||||||
|
print(f"Insertion failed for ID {self.id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def upsert_point(self):
|
async def upsert_point(self):
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,92 @@ from .serializers import (
|
||||||
UpdateCollectionSerializer,
|
UpdateCollectionSerializer,
|
||||||
DeleteCollectionSerializer
|
DeleteCollectionSerializer
|
||||||
)
|
)
|
||||||
|
from model_export.dino_image_matching import get_vectors
|
||||||
from .models import CollectionHandler
|
from .models import CollectionHandler
|
||||||
|
import os
|
||||||
app_router = APIRouter()
|
app_router = APIRouter()
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
@app_router.get("/get_vectors")
|
||||||
|
async def get_vectors_endpoint(
|
||||||
|
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
|
||||||
|
image_path:str=os.getenv("DATASET")
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# Construct path relative to this file
|
||||||
|
base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
|
||||||
|
df = pd.read_excel(excel_path)
|
||||||
|
asin_list = df['ASIN'].dropna().astype(str).tolist()
|
||||||
|
log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")
|
||||||
|
|
||||||
|
# 1. Initialize/Create collection "Product"
|
||||||
|
# DINOv2 vitb14 vector size is 768
|
||||||
|
init_handler = CollectionHandler(
|
||||||
|
collection_name="Product",
|
||||||
|
vector=[],
|
||||||
|
vector_size=768,
|
||||||
|
payload={},
|
||||||
|
client=q
|
||||||
|
)
|
||||||
|
await init_handler.create_collection()
|
||||||
|
|
||||||
|
result_lst = []
|
||||||
|
for index, row in df.iterrows():
|
||||||
|
asin = str(row['ASIN'])
|
||||||
|
if pd.isna(row['ASIN']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
title = str(row['Title'])
|
||||||
|
brand = str(row['Brand'])
|
||||||
|
link = str(row['Image'])
|
||||||
|
|
||||||
|
# Call get_vectors
|
||||||
|
vector = get_vectors(image_input=image_path, item=asin)
|
||||||
|
|
||||||
|
if vector is None:
|
||||||
|
log.warning(f"Skipping {asin} due to missing image/vector")
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"asin": asin,
|
||||||
|
"title": title,
|
||||||
|
"brand": brand,
|
||||||
|
"link": link
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Ingest into Qdrant using CollectionHandler
|
||||||
|
# Use the injected client 'q' and convert index to int
|
||||||
|
handler = CollectionHandler(
|
||||||
|
collection_name="Product",
|
||||||
|
vector=vector,
|
||||||
|
vector_size=768,
|
||||||
|
payload=payload,
|
||||||
|
id=int(index),
|
||||||
|
link=link,
|
||||||
|
asin=asin,
|
||||||
|
brand=brand,
|
||||||
|
client=q
|
||||||
|
)
|
||||||
|
success = await handler.upsert_point()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
result_lst.append({
|
||||||
|
"item": asin,
|
||||||
|
"status": "ingested",
|
||||||
|
"payload": payload
|
||||||
|
})
|
||||||
|
log.info(f"Vector ingested for {asin} (ID: {index})")
|
||||||
|
else:
|
||||||
|
log.error(f"Failed to ingest vector for {asin}")
|
||||||
|
|
||||||
|
return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst})
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app_router.post("/create")
|
@app_router.post("/create")
|
||||||
async def create_collection_endpoint(
|
async def create_collection_endpoint(
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,15 @@
|
||||||
import warnings
|
import warnings
|
||||||
|
from dotenv import load_dotenv
|
||||||
import torch
|
import torch,glob
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
import os
|
||||||
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
from db_setup import get_qdrant_client
|
||||||
|
from vector_db_router.models import CollectionHandler
|
||||||
|
from vector_db_router.serializers import CreateCollectionSerializer
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
# Optional dependency warnings from DINOv2 internals are non-critical.
|
# Optional dependency warnings from DINOv2 internals are non-critical.
|
||||||
warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning)
|
warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning)
|
||||||
|
|
@ -42,13 +48,25 @@ def get_embedding(image_path):
|
||||||
|
|
||||||
return embedding.cpu()
|
return embedding.cpu()
|
||||||
|
|
||||||
# Example
|
def get_vectors(image_input, item):
|
||||||
emb1 = get_embedding(r"data_images\B0B39FFJHF\03.jpg")
|
try:
|
||||||
emb2 = get_embedding(r"data_images\B09RWY127Q\03.jpg")
|
base_dir = os.path.join(os.path.dirname(__file__), "downloaded_images")
|
||||||
|
path = image_input
|
||||||
# Cosine similarity
|
|
||||||
similarity = torch.nn.functional.pdist(
|
# If image_input is not a valid file, try to find one using the item (ASIN)
|
||||||
torch.cat([emb1, emb2])
|
if not (path and os.path.isfile(path)):
|
||||||
)
|
# If path is a directory, use it as base_dir
|
||||||
|
search_base = path if (path and os.path.isdir(path)) else base_dir
|
||||||
print("Distance:", similarity.item())
|
glob_pattern = os.path.join(search_base, item, "*.jpg")
|
||||||
|
jpg_files = glob.glob(glob_pattern)
|
||||||
|
if jpg_files:
|
||||||
|
path = jpg_files[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Generate the vector for the identified image file
|
||||||
|
emb = get_embedding(path)
|
||||||
|
return emb.squeeze().tolist()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error generating vector for {item}: {e}")
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.13.0
|
||||||
|
asyncmy==0.2.11
|
||||||
|
certifi==2026.4.22
|
||||||
|
click==8.3.3
|
||||||
|
contourpy==1.3.3
|
||||||
|
cuda-bindings==13.2.0
|
||||||
|
cuda-pathfinder==1.5.4
|
||||||
|
cuda-toolkit==13.0.2
|
||||||
|
cycler==0.12.1
|
||||||
|
et_xmlfile==2.0.0
|
||||||
|
fastapi==0.136.1
|
||||||
|
filelock==3.29.0
|
||||||
|
fonttools==4.62.1
|
||||||
|
fsspec==2026.4.0
|
||||||
|
greenlet==3.5.0
|
||||||
|
grpcio==1.80.0
|
||||||
|
h11==0.16.0
|
||||||
|
h2==4.3.0
|
||||||
|
hpack==4.1.0
|
||||||
|
httpcore==1.0.9
|
||||||
|
httpx==0.28.1
|
||||||
|
hyperframe==6.1.0
|
||||||
|
idna==3.13
|
||||||
|
Jinja2==3.1.6
|
||||||
|
joblib==1.5.3
|
||||||
|
kiwisolver==1.5.0
|
||||||
|
MarkupSafe==3.0.3
|
||||||
|
matplotlib==3.10.9
|
||||||
|
mpmath==1.3.0
|
||||||
|
networkx==3.6.1
|
||||||
|
numpy==2.4.4
|
||||||
|
nvidia-cublas==13.1.0.3
|
||||||
|
nvidia-cuda-cupti==13.0.85
|
||||||
|
nvidia-cuda-nvrtc==13.0.88
|
||||||
|
nvidia-cuda-runtime==13.0.96
|
||||||
|
nvidia-cudnn-cu13==9.19.0.56
|
||||||
|
nvidia-cufft==12.0.0.61
|
||||||
|
nvidia-cufile==1.15.1.6
|
||||||
|
nvidia-curand==10.4.0.35
|
||||||
|
nvidia-cusolver==12.0.4.66
|
||||||
|
nvidia-cusparse==12.6.3.3
|
||||||
|
nvidia-cusparselt-cu13==0.8.0
|
||||||
|
nvidia-nccl-cu13==2.28.9
|
||||||
|
nvidia-nvjitlink==13.0.88
|
||||||
|
nvidia-nvshmem-cu13==3.4.5
|
||||||
|
nvidia-nvtx==13.0.85
|
||||||
|
openpyxl==3.1.5
|
||||||
|
packaging==26.2
|
||||||
|
pandas==3.0.2
|
||||||
|
pillow==12.2.0
|
||||||
|
portalocker==3.2.0
|
||||||
|
protobuf==7.34.1
|
||||||
|
pydantic==2.13.4
|
||||||
|
pydantic_core==2.46.4
|
||||||
|
pyparsing==3.3.2
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
python-dotenv==1.2.2
|
||||||
|
qdrant-client==1.17.1
|
||||||
|
scikit-learn==1.8.0
|
||||||
|
scipy==1.17.1
|
||||||
|
setuptools==81.0.0
|
||||||
|
six==1.17.0
|
||||||
|
SQLAlchemy==2.0.49
|
||||||
|
starlette==1.0.0
|
||||||
|
sympy==1.14.0
|
||||||
|
threadpoolctl==3.6.0
|
||||||
|
torch==2.11.0
|
||||||
|
torchvision==0.26.0
|
||||||
|
tqdm==4.67.3
|
||||||
|
triton==3.6.0
|
||||||
|
typing-inspection==0.4.2
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
urllib3==2.6.3
|
||||||
|
uvicorn==0.46.0
|
||||||
Loading…
Reference in New Issue