supporting-blog-content/elasticsearch_llm_cache/elasticsearch_llm_cache.py (98 lines of code) (raw):

""" Elasticsearch LLM Cache Library ================================== This library provides an Elasticsearch-based caching mechanism for Language Model (LLM) responses. Through the ElasticsearchLLMCache class, it facilitates the creation, querying, and updating of a cache index to store and retrieve LLM responses based on user prompts. Key Features: ------------- - Initialize a cache index with specified or default settings. - Create the cache index with specified mappings if it does not already exist. - Query the cache for similar prompts using a k-NN (k-Nearest Neighbors) search. - Update the 'last_hit_date' field of a document when a cache hit occurs. - Generate a vector for a given prompt using Elasticsearch's text embedding. - Add new documents (prompts and responses) to the cache. Requirements: ------------- - Elasticsearch - Python 3.6+ - elasticsearch-py library Usage Example: -------------- ```python from elasticsearch import Elasticsearch from elasticsearch_llm_cache import ElasticsearchLLMCache # Initialize Elasticsearch client es_client = Elasticsearch() # Initialize the ElasticsearchLLMCache instance llm_cache = ElasticsearchLLMCache(es_client) # Query the cache prompt_text = "What is the capital of France?" query_result = llm_cache.query(prompt_text) # Add to cache prompt = "What is the capital of France?" response = "Paris" add_result = llm_cache.add(prompt, response) ``` This library is covered in depth in the blog post Elasticsearch as a GenAI Caching Layer https://www.elastic.co/search-labs/elasticsearch-as-a-genai-caching-layer Author: Jeff Vestal Version: 1.0.0 """ from datetime import datetime from typing import Dict, List, Optional from elasticsearch import Elasticsearch import logging # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class ElasticsearchLLMCache: def __init__( self, es_client: Elasticsearch, index_name: Optional[str] = None, es_model_id: Optional[str] = "sentence-transformers__all-distilroberta-v1", create_index=True, ): """ Initialize the ElasticsearchLLMCache instance. :param es_client: Elasticsearch client object. :param index_name: Optional name for the index; defaults to 'llm_cache'. :param es_model_id: Model ID for text embedding; defaults to 'sentence-transformers__all-distilroberta-v1'. :param create_index: Boolean to determine whether to create a new index; defaults to True. """ self.es = es_client self.index_name = index_name or "llm_cache" self.es_model_id = es_model_id if create_index: self.create_index() def create_index(self, dims: Optional[int] = 768) -> Dict: """ Create the index if it does not already exist. :return: Dictionary containing information about the index creation. """ if not self.es.indices.exists(index=self.index_name): mappings = { "mappings": { "properties": { "prompt": {"type": "text"}, "response": {"type": "text"}, "create_date": {"type": "date"}, "last_hit_date": {"type": "date"}, "prompt_vector": { "type": "dense_vector", "dims": dims, "index": True, "similarity": "dot_product", }, } } } self.es.indices.create(index=self.index_name, body=mappings, ignore=400) logger.info(f"Index {self.index_name} created.") return {"cache_index": self.index_name, "created_new": True} else: logger.info(f"Index {self.index_name} already exists.") return {"cache_index": self.index_name, "created_new": False} def update_last_hit_date(self, doc_id: str): """ Update the 'last_hit_date' field of a document to the current datetime. :param doc_id: The ID of the document to update. """ update_body = {"doc": {"last_hit_date": datetime.now()}} self.es.update(index=self.index_name, id=doc_id, body=update_body) def query( self, prompt_text: str, similarity_threshold: Optional[float] = 0.5, num_candidates: Optional[int] = 1000, create_date_gte: Optional[str] = "now-1y/y", ) -> dict: """ Query the index to find similar prompts and update the `last_hit_date` for that document if a hit is found. :param prompt_text: The text of the prompt to find similar entries for. :param similarity_threshold: The similarity threshold for filtering results; defaults to 0.5. :param num_candidates: The number of candidates to consider; defaults to 1000. :param create_date_gte: The date range to consider results; defaults to "now-1y/y". :return: A dictionary containing the hits or an empty dictionary if no hits are found. """ knn = [ { "field": "prompt_vector", "k": 1, "num_candidates": num_candidates, "similarity": similarity_threshold, "query_vector_builder": { "text_embedding": { "model_id": self.es_model_id, "model_text": prompt_text, } }, "filter": {"range": {"create_date": {"gte": create_date_gte}}}, } ] fields = ["prompt", "response"] resp = self.es.search( index=self.index_name, knn=knn, fields=fields, size=1, source=False ) if resp["hits"]["total"]["value"] == 0: return {} else: doc_id = resp["hits"]["hits"][0]["_id"] self.update_last_hit_date(doc_id) return resp["hits"]["hits"][0]["fields"] def _generate_vector(self, prompt: str) -> List[float]: """ Generate a vector for a given prompt using Elasticsearch's text embedding. :param prompt: The text prompt to generate a vector for. :return: A list of floats representing the vector. """ docs = [{"text_field": prompt}] embedding = self.es.ml.infer_trained_model(model_id=self.es_model_id, docs=docs) return embedding["inference_results"][0]["predicted_value"] def add(self, prompt: str, response: str, source: Optional[str] = None) -> Dict: """ Add a new document to the index. :param prompt: The user prompt. :param response: The LLM response. :param source: Optional source identifier for the LLM. :return: A dictionary indicating the successful caching of the new prompt and response. """ prompt_vector = self._generate_vector(prompt=prompt) doc = { "prompt": prompt, "response": response, "create_date": datetime.now(), "last_hit_date": datetime.now(), "prompt_vector": prompt_vector, "source": source, # Optional } try: self.es.index(index=self.index_name, document=doc) return {"success": True} except Exception as e: logger.error(e) return {"success": False, "error": e}