datastore/providers/mongodb_atlas_datastore.py (165 lines of code) (raw):

import os from typing import Dict, List, Any, Optional from loguru import logger from importlib.metadata import version from motor.motor_asyncio import AsyncIOMotorClient from pymongo.driver_info import DriverInfo from pymongo import UpdateOne from datastore.datastore import DataStore from functools import cached_property from models.models import ( Document, DocumentChunk, DocumentChunkWithScore, DocumentMetadataFilter, QueryResult, QueryWithEmbedding, ) from services.chunks import get_document_chunks from services.date import to_unix_timestamp MONGODB_CONNECTION_URI = os.environ.get("MONGODB_URI") MONGODB_DATABASE = os.environ.get("MONGODB_DATABASE", "default") MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "default") MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "default") OVERSAMPLING_FACTOR = 10 MAX_CANDIDATES = 10_000 class MongoDBAtlasDataStore(DataStore): def __init__( self, atlas_connection_uri: str = MONGODB_CONNECTION_URI, index_name: str = MONGODB_INDEX, database_name: str = MONGODB_DATABASE, collection_name: str = MONGODB_COLLECTION, oversampling_factor: float = OVERSAMPLING_FACTOR, ): """ Initialize a MongoDBAtlasDataStore instance. Parameters: - index_name (str, optional): Vector search index. If not provided, default index name is used. - database_name (str, optional): Database. If not provided, default database name is used. - collection_name (str, optional): Collection. If not provided, default collection name is used. - oversampling_factor (float, optional): Oversampling factor for data augmentation. Default is OVERSAMPLING_FACTOR. Raises: - ValueError: If index_name is not a valid string. Attributes: - index_name (str): Name of the index. - database_name (str): Name of the database. - collection_name (str): Name of the collection. - oversampling_factor (float): Oversampling factor for data augmentation. """ self.atlas_connection_uri = atlas_connection_uri self.oversampling_factor = oversampling_factor self.database_name = database_name self.collection_name = collection_name if not (index_name and isinstance(index_name, str)): raise ValueError("Provide a valid index name") self.index_name = index_name # TODO: Create index via driver https://jira.mongodb.org/browse/PYTHON-4175 # self._create_search_index(num_dimensions=1536, path="embedding", similarity="dotProduct", type="vector") @cached_property def client(self): return self._connect_to_mongodb_atlas( atlas_connection_uri=MONGODB_CONNECTION_URI ) async def upsert( self, documents: List[Document], chunk_token_size: Optional[int] = None ) -> List[str]: """ Takes in a list of Documents, chunks them, and upserts the chunks into the database. Return a list the ids of the document chunks. """ chunks = get_document_chunks(documents, chunk_token_size) return await self._upsert(chunks) async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: """ Takes in a list of document chunks and inserts them into the database. Return a list of document ids. """ documents_to_upsert = [] inserted_ids = [] for chunk_list in chunks.values(): for chunk in chunk_list: inserted_ids.append(chunk.id) documents_to_upsert.append( UpdateOne({'_id': chunk.id}, {"$set": chunk.dict()}, upsert=True) ) logger.info(f"Upsert documents into MongoDB collection: {self.database_name}: {self.collection_name}") await self.client[self.database_name][self.collection_name].bulk_write(documents_to_upsert) logger.info("Upsert successful") return inserted_ids async def _query( self, queries: List[QueryWithEmbedding], ) -> List[QueryResult]: """ Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. """ results = [] for query in queries: query_result = await self._execute_embedding_query(query) results.append(query_result) return results async def _execute_embedding_query(self, query: QueryWithEmbedding) -> QueryResult: """ Execute a MongoDB query using vector search on the specified collection and return the result of the query, including matched documents and their scores. """ pipeline = [ { '$vectorSearch': { 'index': self.index_name, 'path': 'embedding', 'queryVector': query.embedding, 'numCandidates': min(query.top_k * self.oversampling_factor, MAX_CANDIDATES), 'limit': query.top_k } }, { '$project': { 'text': 1, 'metadata': 1, 'score': { '$meta': 'vectorSearchScore' } } } ] async with self.client[self.database_name][self.collection_name].aggregate(pipeline) as cursor: results = [ self._convert_mongodb_document_to_document_chunk_with_score(doc) async for doc in cursor ] return QueryResult( query=query.query, results=results, ) async def delete( self, ids: Optional[List[str]] = None, filter: Optional[DocumentMetadataFilter] = None, delete_all: Optional[bool] = None, ) -> bool: """ Removes documents by ids, filter, or everything in the datastore. Returns whether the operation was successful. Note that ids refer to those in the datastore, which are those of the **DocumentChunks** """ # Delete all documents from the collection if delete_all is True if delete_all: logger.info("Deleting all documents from collection") mg_filter = {} # Delete by ids elif ids: logger.info(f"Deleting documents with ids: {ids}") mg_filter = {"_id": {"$in": ids}} # Delete by filters elif filter: mg_filter = self._build_mongo_filter(filter) logger.info(f"Deleting documents with filter: {mg_filter}") # Do nothing else: logger.warning("No criteria set; nothing to delete args: ids: %s, filter: %s delete_all: %s", ids, filter, delete_all) return True try: await self.client[self.database_name][self.collection_name].delete_many(mg_filter) logger.info("Deleted documents successfully") except Exception as e: logger.error("Error deleting documents with filter: %s -- error: %s", mg_filter, e) return False return True def _convert_mongodb_document_to_document_chunk_with_score( self, document: Dict ) -> DocumentChunkWithScore: # Convert MongoDB document to DocumentChunkWithScore return DocumentChunkWithScore( id=document.get("_id"), text=document["text"], metadata=document.get("metadata"), score=document.get("score"), ) def _build_mongo_filter( self, filter: Optional[DocumentMetadataFilter] = None ) -> Dict[str, Any]: """ Generate MongoDB query filters based on the provided DocumentMetadataFilter. """ if filter is None: return {} mongo_filters = { "$and": [], } # For each field in the MetadataFilter, # check if it has a value and add the corresponding MongoDB filter expression for field, value in filter.dict().items(): if value is not None: if field == "start_date": mongo_filters["$and"].append( {"created_at": {"$gte": to_unix_timestamp(value)}} ) elif field == "end_date": mongo_filters["$and"].append( {"created_at": {"$lte": to_unix_timestamp(value)}} ) else: mongo_filters["$and"].append( {f"metadata.{field}": value} ) return mongo_filters @staticmethod def _connect_to_mongodb_atlas(atlas_connection_uri: str): """ Establish a connection to MongoDB Atlas. """ client = AsyncIOMotorClient( atlas_connection_uri, driver=DriverInfo(name="Chatgpt Retrieval Plugin", version=version("chatgpt_retrieval_plugin"))) return client