| | """ |
| | Vector database service for interacting with Qdrant |
| | """ |
| |
|
| | from typing import List, Dict, Any |
| |
|
| | from fastapi import HTTPException |
| | from qdrant_client import QdrantClient |
| | from qdrant_client.models import Distance, PointStruct, VectorParams |
| |
|
class VectorDatabaseClient:
    """Thin wrapper around ``QdrantClient`` bound to a single collection.

    Args:
        url: Qdrant server URL, passed straight to ``QdrantClient``.
        api_key: API key, passed straight to ``QdrantClient``.
        collection_name: Name of the collection every method operates on.
        embedding_size: Vector dimensionality used when
            ``ensure_collection_exists`` has to create the collection.
    """

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.client = QdrantClient(url=url, api_key=api_key)

    def ensure_collection_exists(self) -> None:
        """Create the collection (cosine distance) if it is not present yet.

        Prints a status line either way. An existing collection is left
        untouched; its vector size is NOT compared against ``embedding_size``.

        NOTE(review): the existence check and the creation are two separate
        requests, so a concurrent creator could make ``create_collection``
        fail. Confirm whether this can run from several processes at once.
        """
        # Reuse list_collections() so the name lookup lives in one place.
        if self.collection_name in self.list_collections():
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.embedding_size,
                distance=Distance.COSINE,
            ),
        )
        print(f"✅ Collection '{self.collection_name}' created.")

    def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
        """Upsert a single image embedding into the collection.

        Because this is an upsert, a point that already has ``image_id`` is
        replaced.

        Args:
            image_id: Point ID. NOTE(review): Qdrant only accepts unsigned
                integers or UUID strings as point IDs. Confirm that callers
                pass a UUID here.
            embedding: The vector to store. It is expected to have
                ``embedding_size`` elements to match the collection config.
            payload: Metadata stored alongside the vector.
        """
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                PointStruct(
                    id=image_id,
                    vector=embedding,
                    payload=payload,
                )
            ],
        )

    def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
        """Return the ``limit`` nearest stored points to ``embedding``.

        Args:
            embedding: Query vector.
            limit: Maximum number of hits to return (default 1).

        Returns:
            One dict per hit with keys ``"id"``, ``"score"`` and ``"payload"``,
            in the order Qdrant returned them.
        """
        # NOTE(review): newer qdrant-client releases deprecate `search` in
        # favour of `query_points`. Verify against the pinned client version.
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
        )
        return [{"id": hit.id, "score": hit.score, "payload": hit.payload} for hit in hits]

    def list_collections(self) -> List[str]:
        """Return the names of all collections on the Qdrant server."""
        return [c.name for c in self.client.get_collections().collections]
| |
|