# supporting-blog-content/elasticsearch_llm_cache/elasticRAG_with_cache.py
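"""
Streamlit demo: RAG over Elastic documentation with an Elasticsearch-backed
LLM response cache (ElasticsearchLLMCache) and Elastic APM instrumentation.
"""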
import os
import streamlit as st
import openai
from elasticsearch import Elasticsearch
from string import Template
import elasticapm
import time
from elasticsearch_llm_cache import (
    ElasticsearchLLMCache,  # Import the class from the file
)
## Configure OpenAI client
# openai.api_key = os.environ['OPENAI_API_KEY']
# openai.api_base = os.environ['OPENAI_API_BASE']
# openai.default_model = os.environ['OPENAI_API_ENGINE']
# openai.verify_ssl_certs = False
# Below is for Azure OpenAI
openai.api_type = os.environ["OPENAI_API_TYPE"]
openai.api_base = os.environ["OPENAI_API_BASE"]
openai.api_version = os.environ["OPENAI_API_VERSION"]
openai.verify_ssl_certs = False
engine = os.environ["OPENAI_API_ENGINE"]
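
# `engine` above is the Azure OpenAI deployment name; it is passed as the
# `engine` argument to openai.ChatCompletion.create() in genAI() below.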
# Configure APM and Elasticsearch clients
@st.cache_resource
def initElastic():
    # os.environ['ELASTIC_APM_SERVICE_NAME'] = "elasticsearch_llm_cache_demo"
    apmclient = elasticapm.Client()
    elasticapm.instrument()

    es = Elasticsearch(
        cloud_id=os.environ["ELASTIC_CLOUD_ID"].strip("="),
        basic_auth=(os.environ["ELASTIC_USER"], os.environ["ELASTIC_PASSWORD"]),
        request_timeout=30,
    )
    return apmclient, es
apmclient, es = initElastic()
# Set our data index
index = os.environ["ELASTIC_INDEX_DOCS"]
# Run an Elasticsearch query using hybrid RRF scoring of KNN and BM25
@elasticapm.capture_span("knn_search")
def search_knn(query_text, es):
    query = {
        "bool": {
            "must": [{"match": {"body_content": {"query": query_text}}}],
            "filter": [{"term": {"url_path_dir3": "elasticsearch"}}],
        }
    }

    knn = [
        {
            "field": "chunk-vector",
            "k": 10,
            "num_candidates": 10,
            "filter": {
                "bool": {
                    "filter": [
                        {"range": {"chunklength": {"gte": 0}}},
                        {"term": {"url_path_dir3": "elasticsearch"}},
                    ]
                }
            },
            "query_vector_builder": {
                "text_embedding": {
                    "model_id": "sentence-transformers__msmarco-minilm-l-12-v3",
                    "model_text": query_text,
                }
            },
        }
    ]
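
    # Reciprocal Rank Fusion (RRF) blends the BM25 match results and the kNN
    # vector results above into a single ranked list.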
    rank = {"rrf": {}}

    fields = ["title", "url", "position", "url_path_dir3", "body_content"]

    resp = es.search(
        index=index,
        query=query,
        knn=knn,
        rank=rank,
        fields=fields,
        size=10,
        source=False,
    )

    body = resp["hits"]["hits"][0]["fields"]["body_content"][0]
    url = resp["hits"]["hits"][0]["fields"]["url"][0]
    return body, url
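
# Rough context-length guard: counts whitespace-separated words rather than
# model tokens, so the limit is approximate.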
def truncate_text(text, max_tokens):
    tokens = text.split()
    if len(tokens) <= max_tokens:
        return text
    return " ".join(tokens[:max_tokens])
# Generate a response from ChatGPT based on the given prompt
def genAI(
    prompt,
    model="gpt-3.5-turbo",
    max_tokens=1024,
    max_context_tokens=4000,
    safety_margin=5,
    sys_content=None,
):
    # Truncate the prompt content to fit within the model's context length
    truncated_prompt = truncate_text(
        prompt, max_context_tokens - max_tokens - safety_margin
    )

    response = openai.ChatCompletion.create(
        engine=engine,
        temperature=0,
        messages=[
            {"role": "system", "content": sys_content},
            {"role": "user", "content": truncated_prompt},
        ],
    )

    # APM: add metadata labels of data we want to capture
    elasticapm.label(model=model)
    elasticapm.label(prompt=prompt)
    elasticapm.label(total_tokens=response["usage"]["total_tokens"])
    elasticapm.label(prompt_tokens=response["usage"]["prompt_tokens"])
    elasticapm.label(response_tokens=response["usage"]["completion_tokens"])
    if "USER_HASH" in os.environ:
        elasticapm.label(user=os.environ["USER_HASH"])

    return response["choices"][0]["message"]["content"]
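
# Build the user prompt from the template, send it to the LLM, and render the
# answer (plus the source URL when the doc was actually used) in Streamlit.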
def toLLM(resp, url, usr_prompt, sys_prompt, neg_resp, show_prompt, engine):
    prompt_template = Template(usr_prompt)
    # `query` is the module-level value from the Streamlit text_input
    prompt_formatted = prompt_template.substitute(
        query=query, resp=resp, negResponse=neg_resp
    )

    answer = genAI(prompt_formatted, engine, sys_content=sys_prompt)

    # Display response from LLM
    st.header("Response from LLM")
    st.markdown(answer.strip())

    # We don't need to return a reference URL if it wasn't useful
    if neg_resp not in answer:
        st.write(url)

    # Display full prompt if checkbox was selected
    if show_prompt:
        st.divider()
        st.subheader("Full prompt sent to LLM")
        st.write(prompt_formatted)

    return answer
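
# Thin wrappers so cache lookups and cache writes show up as named spans in
# Elastic APM traces.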
@elasticapm.capture_span("cache_search")
def cache_query(cache, prompt_text):
    return cache.query(prompt_text=prompt_text)


@elasticapm.capture_span("add_to_cache")
def add_to_cache(cache, prompt, response):
    return cache.add(prompt=prompt, response=response)
# sidebar setup
st.sidebar.header("Elasticsearch LLM Cache Info")
### MAIN
# Init Elasticsearch Cache
cache = ElasticsearchLLMCache(
    es_client=es,
    index_name="llm_cache_test",
    create_index=False,  # index creation is deferred below because Streamlit reruns this script
)
st.sidebar.markdown("_creating Elasticsearch Cache_")
# Only want to attempt to create the index on first run
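# The 768 passed to create_index below is the dense-vector dimension for the
# cache index; it needs to match the embeddings ElasticsearchLLMCache generates
# for prompt similarity matching.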
if "index_created" not in st.session_state:
st.sidebar.markdown("_running create_index_")
cache.create_index(768)
# Set the flag so it doesn't run every time
st.session_state.index_created = True
else:
st.sidebar.markdown("_index already created, skipping_")
# Prompt Defaults
prompt_default = """Answer this question: $query
Using only the information from this Elastic Doc: $resp
Format the answer in complete markdown code format
If the answer is not contained in the supplied doc reply '$negResponse' and nothing else"""
system_default = "You are a helpful assistant."
neg_default = "I'm unable to answer the question based on the information I have from Elastic Docs."
st.title("Elasticsearch LLM Cache Demo")
with st.form("chat_form"):
    query = st.text_input(
        "Ask the Elastic Documentation a question: ",
        placeholder="I want to secure my elastic cluster",
    )

    with st.expander("Show Prompt Override Inputs"):
        # Inputs for system and user prompt override
        sys_prompt = st.text_area(
            "create an alternative system prompt",
            placeholder=system_default,
            value=system_default,
        )
        usr_prompt = st.text_area(
            "create an alternative user prompt required -> \$query, \$resp, \$negResponse",
            placeholder=prompt_default,
            value=prompt_default,
        )

        # Default response when the answer is not in the supplied doc
        negResponse = st.text_area(
            "Create an alternative negative response",
            placeholder=neg_default,
            value=neg_default,
        )

    show_full_prompt = st.checkbox("Show Full Prompt Sent to LLM")

    col1, col2 = st.columns(2)
    with col1:
        query_button = st.form_submit_button("Run With Cache Check")
    with col2:
        refresh_button = st.form_submit_button("Refresh Cache with new call to LLM")
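
# "Run With Cache Check" consults the cache before calling the LLM;
# "Refresh Cache" skips the lookup, always calls the LLM, and writes the new
# answer to the cache.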
if query_button:
apmclient.begin_transaction("query")
elasticapm.label(search_method="knn")
elasticapm.label(query=query)
# Start timing
start_time = time.time()
# check the llm cache first
query_check = cache_query(cache, prompt_text=query)
if query_check:
st.sidebar.markdown("_cache match, using cached results_")
st.subheader("Response from Cache")
st.markdown(query_check["response"][0])
# st.button('rerun without cache')
else:
st.sidebar.markdown("_no cache match, querying es and sending to LLM_")
resp, url = search_knn(query, es) # run kNN hybrid query
llmAnswer = toLLM(
resp, url, usr_prompt, sys_prompt, negResponse, show_full_prompt, engine
)
st.sidebar.markdown("_adding prompt and response to cache_")
add_to_cache(cache, query, llmAnswer)
# End timing and print the elapsed time
elapsed_time = time.time() - start_time
st.sidebar.markdown(f"_Time taken: {elapsed_time:.2f} seconds_")
st.markdown(f"_Time taken: {elapsed_time:.2f} seconds_")
apmclient.end_transaction("query", "success")
if refresh_button:
apmclient.begin_transaction("refresh_cache")
st.sidebar.markdown("_refreshing cache_")
"""
Cache Refresh idea: set an 'invalidated' flag in the
already cached document and then call the LLM
"""
elasticapm.label(search_method="knn")
elasticapm.label(query=query)
# Start timing
start_time = time.time()
st.sidebar.markdown("_skipping cache check - sending to LLM_")
resp, url = search_knn(query, es) # run kNN hybrid query
llmAnswer = toLLM(
resp, url, usr_prompt, sys_prompt, negResponse, show_full_prompt, engine
)
st.sidebar.markdown("_adding prompt and response to cache_")
add_to_cache(cache, query, llmAnswer)
# End timing and print the elapsed time
elapsed_time = time.time() - start_time
st.sidebar.markdown(f"_Time taken: {elapsed_time:.2f} seconds_")
st.markdown(f"_Time taken: {elapsed_time:.2f} seconds_")
apmclient.end_transaction("query", "success")
st.sidebar.markdown("_cache refreshed_")
apmclient.end_transaction("refresh_cache", "success")