| from application.vectorstore.base import BaseVectorStore |
| from application.core.settings import settings |
| from application.vectorstore.document_class import Document |
|
|
class MongoDBVectorStore(BaseVectorStore):
    """Vector store backed by MongoDB Atlas Vector Search.

    Documents live in a single collection; each stored document carries the
    raw text under ``text_key``, its embedding vector under ``embedding_key``,
    any caller-supplied metadata fields, and a ``store`` field identifying the
    index path the document belongs to (used to scope searches and deletes).
    """

    def __init__(
        self,
        path: str = "",
        embeddings_key: str = "embeddings",
        collection: str = "documents",
        index_name: str = "vector_search_index",
        text_key: str = "text",
        embedding_key: str = "embedding",
        database: str = "docsgpt",
    ):
        """Connect to MongoDB and prepare the target collection.

        Args:
            path: Index path; the ``"application/indexes/"`` prefix and any
                trailing slashes are stripped and the remainder is used as the
                ``store`` filter value.
            embeddings_key: Key passed through to ``self._get_embeddings`` to
                resolve the embedding backend.
            collection: Name of the MongoDB collection holding the documents.
            index_name: Name of the Atlas vector search index.
            text_key: Document field holding the raw text.
            embedding_key: Document field holding the embedding vector.
            database: MongoDB database name.

        Raises:
            ImportError: If the ``pymongo`` package is not installed.
        """
        self._index_name = index_name
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._embeddings_key = embeddings_key
        self._mongo_uri = settings.MONGO_URI
        # Normalize the path: drop the local on-disk index prefix and any
        # trailing "/" so it matches the "store" field written at insert time.
        self._path = path.replace("application/indexes/", "").rstrip("/")
        self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)

        # Imported lazily so the module stays importable without pymongo.
        try:
            import pymongo
        except ImportError:
            raise ImportError(
                "Could not import pymongo python package. "
                "Please install it with `pip install pymongo`."
            )

        self._client = pymongo.MongoClient(self._mongo_uri)
        self._database = self._client[database]
        self._collection = self._database[collection]

    def search(self, question, k=2, *args, **kwargs):
        """Return the ``k`` documents most similar to ``question``.

        Embeds the query, runs an Atlas ``$vectorSearch`` aggregation
        restricted to this store's path, and wraps each hit in a
        :class:`Document` whose metadata is every remaining document field.
        """
        query_vector = self._embedding.embed_query(question)

        pipeline = [
            {
                "$vectorSearch": {
                    "queryVector": query_vector,
                    "path": self._embedding_key,
                    "limit": k,
                    # Atlas recommends over-requesting candidates (here 10x)
                    # for better recall before the final top-k cut.
                    "numCandidates": k * 10,
                    "index": self._index_name,
                    "filter": {"store": {"$eq": self._path}},
                }
            }
        ]

        cursor = self._collection.aggregate(pipeline)

        results = []
        for doc in cursor:
            text = doc[self._text_key]
            # Strip id, text and embedding; whatever is left is metadata.
            # pop(..., None) keeps iteration alive if a hit lacks a field.
            doc.pop("_id", None)
            doc.pop(self._text_key, None)
            doc.pop(self._embedding_key, None)
            results.append(Document(text, doc))
        return results

    def _insert_texts(self, texts, metadatas):
        """Embed ``texts`` and insert them with their metadata.

        Returns the list of inserted MongoDB ``_id`` values (empty when
        ``texts`` is empty).
        """
        if not texts:
            return []
        embeddings = self._embedding.embed_documents(texts)
        to_insert = [
            {self._text_key: t, self._embedding_key: embedding, **m}
            for t, m, embedding in zip(texts, metadatas, embeddings)
        ]
        insert_result = self._collection.insert_many(to_insert)
        return insert_result.inserted_ids

    def add_texts(
        self,
        texts,
        metadatas=None,
        ids=None,
        refresh_indices=True,
        create_index_if_not_exists=True,
        bulk_kwargs=None,
        **kwargs,
    ):
        """Embed and store ``texts``, inserting in batches of 100.

        ``ids``, ``refresh_indices``, ``create_index_if_not_exists`` and
        ``bulk_kwargs`` are accepted for interface compatibility but are
        currently unused.

        Args:
            texts: Iterable of strings to embed and store.
            metadatas: Optional iterable of per-text metadata dicts; defaults
                to an empty dict per text.

        Returns:
            list: MongoDB ``_id`` values of all inserted documents.
        """
        # Materialize texts first: the default metadatas generator below also
        # iterates ``texts``, and zipping two views of one shared generator
        # would silently drop every other element.
        texts = list(texts)
        batch_size = 100
        _metadatas = metadatas or ({} for _ in texts)
        texts_batch = []
        metadatas_batch = []
        result_ids = []
        for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
            texts_batch.append(text)
            metadatas_batch.append(metadata)
            if (i + 1) % batch_size == 0:
                result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
                texts_batch = []
                metadatas_batch = []
        if texts_batch:
            # Flush the final partial batch.
            result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
        return result_ids

    def delete_index(self, *args, **kwargs):
        """Delete every document belonging to this store's path."""
        self._collection.delete_many({"store": self._path})