elasticsearch/helpers/vectorstore/_async/vectorstore.py (266 lines of code) (raw):

# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import logging import uuid from typing import Any, Callable, Dict, List, Optional from elasticsearch import AsyncElasticsearch from elasticsearch._version import __versionstr__ as lib_version from elasticsearch.helpers import BulkIndexError, async_bulk from elasticsearch.helpers.vectorstore import ( AsyncEmbeddingService, AsyncRetrievalStrategy, ) from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) class AsyncVectorStore: """ VectorStore is a higher-level abstraction of indexing and search. Users can pick from available retrieval strategies. Documents have up to 3 fields: - text_field: the text to be indexed and searched. - metadata: additional information about the document, either schema-free or defined by the supplied metadata_mappings. - vector_field (usually not filled by the user): the embedding vector of the text. Depending on the strategy, vector embeddings are - created by the user beforehand - created by this AsyncVectorStore class in Python - created in-stack by inference pipelines. """ def __init__( self, client: AsyncElasticsearch, *, index: str, retrieval_strategy: AsyncRetrievalStrategy, embedding_service: Optional[AsyncEmbeddingService] = None, num_dimensions: Optional[int] = None, text_field: str = "text_field", vector_field: str = "vector_field", metadata_mappings: Optional[Dict[str, Any]] = None, user_agent: str = f"elasticsearch-py-vs/{lib_version}", custom_index_settings: Optional[Dict[str, Any]] = None, ) -> None: """ :param user_header: user agent header specific to the 3rd party integration. Used for usage tracking in Elastic Cloud. :param index: The name of the index to query. :param retrieval_strategy: how to index and search the data. See the strategies module for availble strategies. :param text_field: Name of the field with the textual data. :param vector_field: For strategies that perform embedding inference in Python, the embedding vector goes in this field. :param client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. :param custom_index_settings: A dictionary of custom settings for the index. This can include configurations like the number of shards, number of replicas, analysis settings, and other index-specific settings. If not provided, default settings will be used. Note that if the same setting is provided by both the user and the strategy, will raise an error. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserves existing (non-user-agent) headers. client = client.options(headers={"User-Agent": user_agent}) if hasattr(retrieval_strategy, "text_field"): retrieval_strategy.text_field = text_field if hasattr(retrieval_strategy, "vector_field"): retrieval_strategy.vector_field = vector_field self.client = client self.index = index self.retrieval_strategy = retrieval_strategy self.embedding_service = embedding_service self.num_dimensions = num_dimensions self.text_field = text_field self.vector_field = vector_field self.metadata_mappings = metadata_mappings self.custom_index_settings = custom_index_settings async def close(self) -> None: return await self.client.close() async def add_texts( self, texts: List[str], *, metadatas: Optional[List[Dict[str, Any]]] = None, vectors: Optional[List[List[float]]] = None, ids: Optional[List[str]] = None, refresh_indices: bool = True, create_index_if_not_exists: bool = True, bulk_kwargs: Optional[Dict[str, Any]] = None, ) -> List[str]: """Add documents to the Elasticsearch index. :param texts: List of text documents. :param metadata: Optional list of document metadata. Must be of same length as texts. :param vectors: Optional list of embedding vectors. Must be of same length as texts. :param ids: Optional list of ID strings. Must be of same length as texts. :param refresh_indices: Whether to refresh the index after deleting documents. Defaults to True. :param create_index_if_not_exists: Whether to create the index if it does not exist. Defaults to True. :param bulk_kwargs: Arguments to pass to the bulk function when indexing (for example chunk_size). :return: List of IDs of the created documents, either echoing the provided one or returning newly created ones. """ bulk_kwargs = bulk_kwargs or {} ids = ids or [str(uuid.uuid4()) for _ in texts] requests = [] if create_index_if_not_exists: await self._create_index_if_not_exists() if self.embedding_service and not vectors: vectors = await self.embedding_service.embed_documents(texts) for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} request: Dict[str, Any] = { "_op_type": "index", "_index": self.index, self.text_field: text, "metadata": metadata, "_id": ids[i], } if vectors: request[self.vector_field] = vectors[i] requests.append(request) if len(requests) > 0: try: success, failed = await async_bulk( self.client, requests, stats_only=True, refresh=refresh_indices, **bulk_kwargs, ) logger.debug(f"added texts {ids} to index") return ids except BulkIndexError as e: logger.error(f"Error adding texts: {e}") firstError = e.errors[0].get("index", {}).get("error", {}) logger.error(f"First error reason: {firstError.get('reason')}") raise e else: logger.debug("No texts to add to index") return [] async def delete( # type: ignore[no-untyped-def] self, *, ids: Optional[List[str]] = None, query: Optional[Dict[str, Any]] = None, refresh_indices: bool = True, **delete_kwargs, ) -> bool: """Delete documents from the Elasticsearch index. :param ids: List of IDs of documents to delete. :param refresh_indices: Whether to refresh the index after deleting documents. Defaults to True. :return: True if deletion was successful. """ if ids is not None and query is not None: raise ValueError("one of ids or query must be specified") elif ids is None and query is None: raise ValueError("either specify ids or query") try: if ids: body = [ {"_op_type": "delete", "_index": self.index, "_id": _id} for _id in ids ] await async_bulk( self.client, body, refresh=refresh_indices, ignore_status=404, **delete_kwargs, ) logger.debug(f"Deleted {len(body)} texts from index") else: await self.client.delete_by_query( index=self.index, query=query, refresh=refresh_indices, **delete_kwargs, ) except BulkIndexError as e: logger.error(f"Error deleting texts: {e}") firstError = e.errors[0].get("index", {}).get("error", {}) logger.error(f"First error reason: {firstError.get('reason')}") raise e return True async def search( self, *, query: Optional[str] = None, query_vector: Optional[List[float]] = None, k: int = 4, num_candidates: int = 50, fields: Optional[List[str]] = None, filter: Optional[List[Dict[str, Any]]] = None, custom_query: Optional[ Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] ] = None, ) -> List[Dict[str, Any]]: """ :param query: Input query string. :param query_vector: Input embedding vector. If given, input query string is ignored. :param k: Number of returned results. :param num_candidates: Number of candidates to fetch from data nodes in knn. :param fields: List of field names to return. :param filter: Elasticsearch filters to apply. :param custom_query: Function to modify the Elasticsearch query body before it is sent to Elasticsearch. :return: List of document hits. Includes _index, _id, _score and _source. """ if fields is None: fields = [] if "metadata" not in fields: fields.append("metadata") if self.text_field not in fields: fields.append(self.text_field) if self.embedding_service and not query_vector: if not query: raise ValueError("specify a query or a query_vector to search") query_vector = await self.embedding_service.embed_query(query) query_body = self.retrieval_strategy.es_query( query=query, query_vector=query_vector, text_field=self.text_field, vector_field=self.vector_field, k=k, num_candidates=num_candidates, filter=filter or [], ) if custom_query is not None: query_body = custom_query(query_body, query) logger.debug(f"Calling custom_query, Query body now: {query_body}") response = await self.client.search( index=self.index, **query_body, size=k, source=True, source_includes=fields, ) hits: List[Dict[str, Any]] = response["hits"]["hits"] return hits async def _create_index_if_not_exists(self) -> None: exists = await self.client.indices.exists(index=self.index) if exists.meta.status == 200: logger.debug(f"Index {self.index} already exists. Skipping creation.") return if self.retrieval_strategy.needs_inference(): if not self.num_dimensions and not self.embedding_service: raise ValueError( "retrieval strategy requires embeddings; either embedding_service " "or num_dimensions need to be specified" ) if not self.num_dimensions and self.embedding_service: vector = await self.embedding_service.embed_query("get num dimensions") self.num_dimensions = len(vector) mappings, settings = self.retrieval_strategy.es_mappings_settings( text_field=self.text_field, vector_field=self.vector_field, num_dimensions=self.num_dimensions, ) if self.custom_index_settings: conflicting_keys = set(self.custom_index_settings.keys()) & set( settings.keys() ) if conflicting_keys: raise ValueError(f"Conflicting settings: {conflicting_keys}") else: settings.update(self.custom_index_settings) if self.metadata_mappings: metadata = mappings["properties"].get("metadata", {"properties": {}}) for key in self.metadata_mappings.keys(): if key in metadata: raise ValueError(f"metadata key {key} already exists in mappings") metadata = dict(**metadata["properties"], **self.metadata_mappings) mappings["properties"]["metadata"] = {"properties": metadata} await self.retrieval_strategy.before_index_creation( client=self.client, text_field=self.text_field, vector_field=self.vector_field, ) await self.client.indices.create( index=self.index, mappings=mappings, settings=settings ) async def max_marginal_relevance_search( self, *, query: Optional[str] = None, query_embedding: Optional[List[float]] = None, embedding_service: Optional[AsyncEmbeddingService] = None, vector_field: str, k: int = 4, num_candidates: int = 20, lambda_mult: float = 0.5, fields: Optional[List[str]] = None, custom_query: Optional[ Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] ] = None, ) -> List[Dict[str, Any]]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. :param query (str): Text to look up documents similar to. :param query_embedding: Input embedding vector. If given, input query string is ignored. :param k (int): Number of Documents to return. Defaults to 4. :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. :param lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. :param fields: Other fields to get from elasticsearch source. These fields will be added to the document metadata. :return: A list of Documents selected by maximal marginal relevance. """ remove_vector_query_field_from_metadata = True if fields is None: fields = [vector_field] elif vector_field not in fields: fields.append(vector_field) else: remove_vector_query_field_from_metadata = False # Embed the query if query_embedding: query_vector = query_embedding else: if not query: raise ValueError("specify either query or query_embedding to search") elif embedding_service: query_vector = await embedding_service.embed_query(query) elif self.embedding_service: query_vector = await self.embedding_service.embed_query(query) else: raise ValueError("specify embedding_service to search with query") # Fetch the initial documents got_hits = await self.search( query=None, query_vector=query_vector, k=num_candidates, fields=fields, custom_query=custom_query, ) # Get the embeddings for the fetched documents got_embeddings = [hit["_source"][vector_field] for hit in got_hits] # Select documents using maximal marginal relevance selected_indices = maximal_marginal_relevance( query_vector, got_embeddings, lambda_mult=lambda_mult, k=k ) selected_hits = [got_hits[i] for i in selected_indices] if remove_vector_query_field_from_metadata: for hit in selected_hits: del hit["_source"][vector_field] return selected_hits