injection completed

AMB_DEV
ambag12 2026-05-08 18:56:33 +05:00
parent dcaa4cc8c9
commit 26b0b7f1a9
8 changed files with 230 additions and 31 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@
*pyc** *pyc**
**pycache** **pycache**
*agent/** *agent/**
**downloaded_images**

View File

@ -1,9 +0,0 @@
MYSQL_HOST=localhost
MYSQL_PORT=3306
MYSQL_USER=root
MYSQL_PASSWORD='AmB@ig123'
MYSQL_DATABASE=listing_radar

View File

@ -1,3 +1,9 @@
import sys
import os
# Add the project root to sys.path to allow imports from model_export
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from fastapi import FastAPI, status from fastapi import FastAPI, status
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from dotenv import load_dotenv from dotenv import load_dotenv

View File

@ -3,13 +3,24 @@ from qdrant_client.models import PointStruct
from typing import Dict, Any from typing import Dict, Any
class CollectionHandler: class CollectionHandler:
def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict, id: int): def __init__(self, collection_name: str, vector: Any, vector_size: int, payload: Dict,
id: int=None,
link: str=None,
asin: str=None,
category: str=None,
brand: str=None,
client: AsyncQdrantClient=None
):
self.collection_name = collection_name self.collection_name = collection_name
self.vector = vector self.vector = vector
self.id = id
self.vector_size = vector_size self.vector_size = vector_size
self.payload = payload self.payload = payload
self.id = id self.link = link
self.client = AsyncQdrantClient("localhost", port=6333) self.asin = asin
self.category = category
self.brand = brand
self.client = client if client else AsyncQdrantClient("localhost", port=6333)
async def create_collection(self): async def create_collection(self):
try: try:
@ -18,22 +29,35 @@ class CollectionHandler:
await self.client.create_collection( await self.client.create_collection(
collection_name=self.collection_name, collection_name=self.collection_name,
vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.COSINE), vectors_config=models.VectorParams(size=self.vector_size, distance=models.Distance.EUCLID),
optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000) optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000)
) )
# Creating payload indexes as per project logic # Creating payload indexes as per project logic
await self.client.create_payload_index( await self.client.create_payload_index(
collection_name=self.collection_name, collection_name=self.collection_name,
field_name="Product_ID", field_name="link",
field_schema=models.PayloadSchemaType.KEYWORD field_schema=models.PayloadSchemaType.KEYWORD
) )
await self.client.create_payload_index( await self.client.create_payload_index(
collection_name=self.collection_name, collection_name=self.collection_name,
field_name="Product_Link", field_name="title",
field_schema=models.PayloadSchemaType.KEYWORD field_schema=models.PayloadSchemaType.KEYWORD
) )
await self.client.create_payload_index(
collection_name=self.collection_name,
field_name="brand",
field_schema=models.PayloadSchemaType.KEYWORD
)
await self.client.create_payload_index(
collection_name=self.collection_name,
field_name="asin",
field_schema=models.PayloadSchemaType.KEYWORD
)
return {"message": f"Collection {self.collection_name} created successfully"} return {"message": f"Collection {self.collection_name} created successfully"}
except Exception as e: except Exception as e:
return {"message": str(e)} return {"message": str(e)}
@ -46,10 +70,10 @@ class CollectionHandler:
PointStruct(id=self.id, vector=self.vector, payload=self.payload) PointStruct(id=self.id, vector=self.vector, payload=self.payload)
] ]
) )
print("Data inserted successfully")
return True return True
except Exception as e: except Exception as e:
print("Insertion failed: ", e) # Note: In a real app we'd use a logger here
print(f"Insertion failed for ID {self.id}: {e}")
return False return False
async def upsert_point(self): async def upsert_point(self):

View File

@ -9,9 +9,92 @@ from .serializers import (
UpdateCollectionSerializer, UpdateCollectionSerializer,
DeleteCollectionSerializer DeleteCollectionSerializer
) )
from model_export.dino_image_matching import get_vectors
from .models import CollectionHandler from .models import CollectionHandler
import os
app_router = APIRouter() app_router = APIRouter()
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
import pandas as pd
@app_router.get("/get_vectors")
async def get_vectors_endpoint(
q: Annotated[AsyncQdrantClient, Depends(get_qdrant_client)],
image_path:str=os.getenv("DATASET")
):
try:
# Construct path relative to this file
base_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
excel_path = os.path.join(base_root, "model_export", "listing_data.xlsx")
df = pd.read_excel(excel_path)
asin_list = df['ASIN'].dropna().astype(str).tolist()
log.info(f"Generating vectors and ingesting {len(asin_list)} ASINs into Qdrant from {excel_path}")
# 1. Initialize/Create collection "Product"
# DINOv2 vitb14 vector size is 768
init_handler = CollectionHandler(
collection_name="Product",
vector=[],
vector_size=768,
payload={},
client=q
)
await init_handler.create_collection()
result_lst = []
for index, row in df.iterrows():
asin = str(row['ASIN'])
if pd.isna(row['ASIN']):
continue
title = str(row['Title'])
brand = str(row['Brand'])
link = str(row['Image'])
# Call get_vectors
vector = get_vectors(image_input=image_path, item=asin)
if vector is None:
log.warning(f"Skipping {asin} due to missing image/vector")
continue
payload = {
"asin": asin,
"title": title,
"brand": brand,
"link": link
}
# 2. Ingest into Qdrant using CollectionHandler
# Use the injected client 'q' and convert index to int
handler = CollectionHandler(
collection_name="Product",
vector=vector,
vector_size=768,
payload=payload,
id=int(index),
link=link,
asin=asin,
brand=brand,
client=q
)
success = await handler.upsert_point()
if success:
result_lst.append({
"item": asin,
"status": "ingested",
"payload": payload
})
log.info(f"Vector ingested for {asin} (ID: {index})")
else:
log.error(f"Failed to ingest vector for {asin}")
return JSONResponse({"status": "success", "message": f"Ingested {len(result_lst)} items into Product collection", "result": result_lst})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app_router.post("/create") @app_router.post("/create")
async def create_collection_endpoint( async def create_collection_endpoint(

View File

@ -1,9 +1,15 @@
import warnings import warnings
from dotenv import load_dotenv
import torch import torch,glob
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
import os
from qdrant_client import AsyncQdrantClient, models
from db_setup import get_qdrant_client
from vector_db_router.models import CollectionHandler
from vector_db_router.serializers import CreateCollectionSerializer
import torch.nn.functional as F import torch.nn.functional as F
load_dotenv()
# Optional dependency warnings from DINOv2 internals are non-critical. # Optional dependency warnings from DINOv2 internals are non-critical.
warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning) warnings.filterwarnings("ignore", message="xFormers is not available.*", category=UserWarning)
@ -42,13 +48,25 @@ def get_embedding(image_path):
return embedding.cpu() return embedding.cpu()
# Example def get_vectors(image_input, item):
emb1 = get_embedding(r"data_images\B0B39FFJHF\03.jpg") try:
emb2 = get_embedding(r"data_images\B09RWY127Q\03.jpg") base_dir = os.path.join(os.path.dirname(__file__), "downloaded_images")
path = image_input
# Cosine similarity # If image_input is not a valid file, try to find one using the item (ASIN)
similarity = torch.nn.functional.pdist( if not (path and os.path.isfile(path)):
torch.cat([emb1, emb2]) # If path is a directory, use it as base_dir
) search_base = path if (path and os.path.isdir(path)) else base_dir
glob_pattern = os.path.join(search_base, item, "*.jpg")
jpg_files = glob.glob(glob_pattern)
if jpg_files:
path = jpg_files[0]
else:
return None
print("Distance:", similarity.item()) # Generate the vector for the identified image file
emb = get_embedding(path)
return emb.squeeze().tolist()
except Exception as e:
print(f"Error generating vector for {item}: {e}")
return None

76
req.txt Normal file
View File

@ -0,0 +1,76 @@
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.13.0
asyncmy==0.2.11
certifi==2026.4.22
click==8.3.3
contourpy==1.3.3
cuda-bindings==13.2.0
cuda-pathfinder==1.5.4
cuda-toolkit==13.0.2
cycler==0.12.1
et_xmlfile==2.0.0
fastapi==0.136.1
filelock==3.29.0
fonttools==4.62.1
fsspec==2026.4.0
greenlet==3.5.0
grpcio==1.80.0
h11==0.16.0
h2==4.3.0
hpack==4.1.0
httpcore==1.0.9
httpx==0.28.1
hyperframe==6.1.0
idna==3.13
Jinja2==3.1.6
joblib==1.5.3
kiwisolver==1.5.0
MarkupSafe==3.0.3
matplotlib==3.10.9
mpmath==1.3.0
networkx==3.6.1
numpy==2.4.4
nvidia-cublas==13.1.0.3
nvidia-cuda-cupti==13.0.85
nvidia-cuda-nvrtc==13.0.88
nvidia-cuda-runtime==13.0.96
nvidia-cudnn-cu13==9.19.0.56
nvidia-cufft==12.0.0.61
nvidia-cufile==1.15.1.6
nvidia-curand==10.4.0.35
nvidia-cusolver==12.0.4.66
nvidia-cusparse==12.6.3.3
nvidia-cusparselt-cu13==0.8.0
nvidia-nccl-cu13==2.28.9
nvidia-nvjitlink==13.0.88
nvidia-nvshmem-cu13==3.4.5
nvidia-nvtx==13.0.85
openpyxl==3.1.5
packaging==26.2
pandas==3.0.2
pillow==12.2.0
portalocker==3.2.0
protobuf==7.34.1
pydantic==2.13.4
pydantic_core==2.46.4
pyparsing==3.3.2
python-dateutil==2.9.0.post0
python-dotenv==1.2.2
qdrant-client==1.17.1
scikit-learn==1.8.0
scipy==1.17.1
setuptools==81.0.0
six==1.17.0
SQLAlchemy==2.0.49
starlette==1.0.0
sympy==1.14.0
threadpoolctl==3.6.0
torch==2.11.0
torchvision==0.26.0
tqdm==4.67.3
triton==3.6.0
typing-inspection==0.4.2
typing_extensions==4.15.0
urllib3==2.6.3
uvicorn==0.46.0