datastore/providers/elasticsearch_datastore.py

import os
from typing import Dict, List, Any, Optional

import elasticsearch
from elasticsearch import Elasticsearch, helpers
from loguru import logger

from datastore.datastore import DataStore
from models.models import (
    DocumentChunk,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
)
from services.date import to_unix_timestamp

ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "http://localhost:9200")
ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID")
ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME")
ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD")
ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY")
ELASTICSEARCH_INDEX = os.environ.get("ELASTICSEARCH_INDEX")
ELASTICSEARCH_REPLICAS = int(os.environ.get("ELASTICSEARCH_REPLICAS", "1"))
ELASTICSEARCH_SHARDS = int(os.environ.get("ELASTICSEARCH_SHARDS", "1"))

VECTOR_SIZE = int(os.environ.get("EMBEDDING_DIMENSION", "256"))

UPSERT_BATCH_SIZE = 100


class ElasticsearchDataStore(DataStore):
    def __init__(
        self,
        index_name: Optional[str] = None,
        vector_size: int = VECTOR_SIZE,
        similarity: str = "cosine",
        replicas: int = ELASTICSEARCH_REPLICAS,
        shards: int = ELASTICSEARCH_SHARDS,
        recreate_index: bool = True,
    ):
        """
        Args:
            index_name: Name of the index to be used.
            vector_size: Size of the embeddings stored in the index.
            similarity: Any of "cosine" / "l2_norm" / "dot_product".
        """
        assert similarity in [
            "cosine",
            "l2_norm",
            "dot_product",
        ], "Similarity must be one of 'cosine' / 'l2_norm' / 'dot_product'."
        assert replicas > 0, "Replicas must be greater than 0."
        assert shards > 0, "Shards must be greater than 0."

        self.client = connect_to_elasticsearch(
            ELASTICSEARCH_URL,
            ELASTICSEARCH_CLOUD_ID,
            ELASTICSEARCH_API_KEY,
            ELASTICSEARCH_USERNAME,
            ELASTICSEARCH_PASSWORD,
        )
        assert index_name or ELASTICSEARCH_INDEX, "Please provide an index name."
        self.index_name = index_name or ELASTICSEARCH_INDEX or ""

        replicas = replicas or ELASTICSEARCH_REPLICAS
        shards = shards or ELASTICSEARCH_SHARDS

        # Set up the index so documents can be inserted and queried
        self._set_up_index(vector_size, similarity, replicas, shards, recreate_index)

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a dict mapping document ids to lists of document chunks and
        inserts them into the index.
        Returns a list of document ids.
        """
        actions = []
        for chunk_list in chunks.values():
            for chunk in chunk_list:
                actions.extend(
                    self._convert_document_chunk_to_es_document_operation(chunk)
                )

        self.client.bulk(operations=actions, index=self.index_name)
        return list(chunks.keys())

    async def _query(
        self,
        queries: List[QueryWithEmbedding],
    ) -> List[QueryResult]:
        """
        Takes in a list of queries with embeddings and filters and returns
        a list of query results with matching document chunks and scores.
        """
        searches = self._convert_queries_to_msearch_query(queries)
        results = self.client.msearch(searches=searches)
        return [
            QueryResult(
                query=query.query,
                results=[
                    self._convert_hit_to_document_chunk_with_score(hit)
                    for hit in result["hits"]["hits"]
                ],
            )
            for query, result in zip(queries, results["responses"])
        ]

    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """
        Removes vectors by ids, filter, or everything in the datastore.
        Returns whether the operation was successful.
""" # Delete all vectors from the index if delete_all is True if delete_all: try: logger.info(f"Deleting all vectors from index") self.client.delete_by_query( index=self.index_name, query={"match_all": {}} ) logger.info(f"Deleted all vectors successfully") return True except Exception as e: logger.error(f"Error deleting all vectors: {e}") raise e # Convert the metadata filter object to a dict with elasticsearch filter expressions es_filters = self._get_es_filters(filter) # Delete vectors that match the filter from the index if the filter is not empty if es_filters != {}: try: logger.info(f"Deleting vectors with filter {es_filters}") self.client.delete_by_query(index=self.index_name, query=es_filters) logger.info(f"Deleted vectors with filter successfully") except Exception as e: logger.error(f"Error deleting vectors with filter: {e}") raise e if ids: try: documents_to_delete = [doc_id for doc_id in ids] logger.info(f"Deleting {len(documents_to_delete)} documents") res = self.client.delete_by_query( index=self.index_name, query={"terms": {"metadata.document_id": documents_to_delete}}, ) logger.info(f"Deleted documents successfully") except Exception as e: logger.error(f"Error deleting documents: {e}") raise e return True def _get_es_filters( self, filter: Optional[DocumentMetadataFilter] = None ) -> Dict[str, Any]: if filter is None: return {} es_filters = { "bool": { "must": [], } } # For each field in the MetadataFilter, check if it has a value and add the corresponding pinecone filter expression # For start_date and end_date, uses the range query - gte and lte operators respectively # For other fields, uses the term query for field, value in filter.dict().items(): if value is not None: if field == "start_date": es_filters["bool"]["must"].append( {"range": {"created_at": {"gte": to_unix_timestamp(value)}}} ) elif field == "end_date": es_filters["bool"]["must"].append( {"range": {"created_at": {"lte": to_unix_timestamp(value)}}} ) else: es_filters["bool"]["must"].append( {"term": {f"metadata.{field}": value}} ) return es_filters def _convert_document_chunk_to_es_document_operation( self, document_chunk: DocumentChunk ) -> List[Dict]: created_at = ( to_unix_timestamp(document_chunk.metadata.created_at) if document_chunk.metadata.created_at is not None else None ) action_and_metadata = { "index": { "_index": self.index_name, "_id": document_chunk.id, } } source = { "id": document_chunk.id, "text": document_chunk.text, "metadata": document_chunk.metadata.dict(), "created_at": created_at, "embedding": document_chunk.embedding, } return [action_and_metadata, source] def _convert_queries_to_msearch_query(self, queries: List[QueryWithEmbedding]): searches = [] for query in queries: searches.append({"index": self.index_name}) searches.append( { "_source": True, "knn": { "field": "embedding", "query_vector": query.embedding, "k": query.top_k, "num_candidates": query.top_k, }, "size": query.top_k, } ) return searches def _convert_hit_to_document_chunk_with_score(self, hit) -> DocumentChunkWithScore: return DocumentChunkWithScore( id=hit["_id"], text=hit["_source"]["text"], # type: ignore metadata=hit["_source"]["metadata"], # type: ignore embedding=hit["_source"]["embedding"], # type: ignore score=hit["_score"], ) def _set_up_index( self, vector_size: int, similarity: str, replicas: int, shards: int, recreate_index: bool, ) -> None: if recreate_index: self._recreate_index(similarity, vector_size, replicas, shards) try: index_mapping = self.client.indices.get_mapping(index=self.index_name) 
            current_similarity = index_mapping[self.index_name]["mappings"]["properties"]["embedding"]["similarity"]  # type: ignore
            current_vector_size = index_mapping[self.index_name]["mappings"]["properties"]["embedding"]["dims"]  # type: ignore

            if current_similarity != similarity:
                raise ValueError(
                    f"Index '{self.index_name}' already exists in Elasticsearch, "
                    f"but it is configured with the similarity '{current_similarity}'. "
                    f"If you want to use that index with a different similarity, "
                    f"please set the `recreate_index=True` argument."
                )

            if current_vector_size != vector_size:
                raise ValueError(
                    f"Index '{self.index_name}' already exists in Elasticsearch, "
                    f"but it is configured with the vector size '{current_vector_size}'. "
                    f"If you want to use that index with a different vector size, "
                    f"please set the `recreate_index=True` argument."
                )
        except elasticsearch.exceptions.NotFoundError:
            self._recreate_index(similarity, vector_size, replicas, shards)

    def _recreate_index(
        self, similarity: str, vector_size: int, replicas: int, shards: int
    ) -> None:
        settings = {
            "index": {
                "number_of_shards": shards,
                "number_of_replicas": replicas,
                "refresh_interval": "1s",
            }
        }
        mappings = {
            "properties": {
                "embedding": {
                    "type": "dense_vector",
                    "dims": vector_size,
                    "index": True,
                    "similarity": similarity,
                }
            }
        }

        self.client.indices.delete(
            index=self.index_name, ignore_unavailable=True, allow_no_indices=True
        )
        self.client.indices.create(
            index=self.index_name, mappings=mappings, settings=settings
        )


def connect_to_elasticsearch(
    elasticsearch_url=None, cloud_id=None, api_key=None, username=None, password=None
):
    # Check that only one of elasticsearch_url and cloud_id is defined
    if elasticsearch_url and cloud_id:
        raise ValueError(
            "Both elasticsearch_url and cloud_id are defined. Please provide only one."
        )

    # Initialize the connection parameters dictionary
    connection_params = {}

    # Define the connection based on the provided parameters
    if elasticsearch_url:
        connection_params["hosts"] = [elasticsearch_url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either elasticsearch_url or cloud_id.")

    # Add authentication details based on the provided parameters
    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)
    else:
        logger.warning(
            "No authentication details provided. Please consider using an api_key "
            "or username and password to secure your connection."
        )

    # Establish the Elasticsearch client connection and verify that it is reachable
    es_client = Elasticsearch(**connection_params)
    try:
        es_client.info()
    except Exception as e:
        logger.error(f"Error connecting to Elasticsearch: {e}")
        raise e

    return es_client
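

# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of driving this datastore directly, assuming an
# Elasticsearch instance reachable through ELASTICSEARCH_URL or
# ELASTICSEARCH_CLOUD_ID and embeddings whose length matches EMBEDDING_DIMENSION.
# The index name, ids, text, and vector values below are hypothetical, and
# DocumentChunkMetadata is assumed to expose a `document_id` field as defined in
# models.models. The private _upsert/_query helpers are called directly here to
# sidestep the embedding service used by the public DataStore methods.
if __name__ == "__main__":
    import asyncio

    from models.models import DocumentChunkMetadata

    async def _example() -> None:
        # recreate_index defaults to True, so this drops and recreates the index.
        datastore = ElasticsearchDataStore(index_name="example-index")

        # Index a single hypothetical chunk with a non-zero placeholder embedding
        # (cosine similarity rejects zero-magnitude vectors).
        chunk = DocumentChunk(
            id="doc-1_chunk-0",
            text="Elasticsearch stores dense vectors for kNN search.",
            metadata=DocumentChunkMetadata(document_id="doc-1"),
            embedding=[0.1] * VECTOR_SIZE,
        )
        await datastore._upsert({"doc-1": [chunk]})

        # Force a refresh so the freshly indexed chunk is visible to the query.
        datastore.client.indices.refresh(index=datastore.index_name)

        # Query with a placeholder embedding and print the scored hits.
        query = QueryWithEmbedding(
            query="How are vectors stored?",
            top_k=3,
            embedding=[0.1] * VECTOR_SIZE,
        )
        for result in await datastore._query([query]):
            for hit in result.results:
                print(hit.id, hit.score)

    asyncio.run(_example())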