{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "nFbQGw2POViM"
},
"source": [
"# Lab 2-2 RAG"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iSVyZRkvmqyc"
},
"source": [
"## Setup Environment\n",
"The following code loads the environment variables required to run this notebook.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BwWfBNdUmqyd"
},
"outputs": [],
"source": [
"FILE=\"Session 2\"\n",
"\n",
"! pip install -qqq git+https://github.com/elastic/notebook-workshop-loader.git@main\n",
"from notebookworkshoploader import loader\n",
"import os\n",
"from dotenv import load_dotenv\n",
"\n",
"if os.path.isfile(\"../env\"):\n",
" load_dotenv(\"../env\", override=True)\n",
" print('Successfully loaded environment variables from local env file')\n",
"else:\n",
" loader.load_remote_env(file=FILE, env_url=\"https://notebook-workshop-api-voldmqr2bq-uc.a.run.app\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ln-8SRvAI-jS"
},
"outputs": [],
"source": [
"! pip install -qqq tiktoken==0.5.2 cohere==4.38 openai==1.3.9\n",
"! pip install -q streamlit elasticsearch elastic-apm inquirer python-dotenv\n",
"! pip install elasticsearch_llm_cache\n",
"! npm install localtunnel --loglevel=error"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E0uQujqZclf0"
},
"source": [
"# Lab RAG\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t7RmurdZNPg-"
},
"source": [
"## Main Script\n",
"Running this cell will write a file named `app.py` into the Colab environment"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Wc02OlkpOSd2",
"outputId": "c7137c69-a9ae-4b87-c8c5-980afcff9b71"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Overwriting app.py\n"
]
}
],
"source": [
"%%writefile app.py\n",
"\n",
"import os\n",
"import streamlit as st\n",
"import openai\n",
"from openai import OpenAI\n",
"from elasticsearch import Elasticsearch\n",
"import elasticapm\n",
"import base64\n",
"from elasticsearch_llm_cache.elasticsearch_llm_cache import ElasticsearchLLMCache\n",
"import time\n",
"import json\n",
"import textwrap\n",
"\n",
"######################################\n",
"# Streamlit Configuration\n",
"st.set_page_config(layout=\"wide\")\n",
"\n",
"\n",
"# wrap text when printing, because colab scrolls output to the right too much\n",
"def wrap_text(text, width):\n",
" wrapped_text = textwrap.wrap(text, width)\n",
" return '\\n'.join(wrapped_text)\n",
"\n",
"\n",
"@st.cache_data()\n",
"def get_base64(bin_file):\n",
" with open(bin_file, 'rb') as f:\n",
" data = f.read()\n",
" return base64.b64encode(data).decode()\n",
"\n",
"\n",
"def set_background(png_file):\n",
" bin_str = get_base64(png_file)\n",
" page_bg_img = '''\n",
" <style>\n",
" .stApp {\n",
" background-image: url(\"data:image/png;base64,%s\");\n",
" background-size: cover;\n",
" }\n",
" </style>\n",
" ''' % bin_str\n",
" st.markdown(page_bg_img, unsafe_allow_html=True)\n",
" return\n",
"\n",
"\n",
"set_background('images/background-dark2.jpeg')\n",
"\n",
"\n",
"######################################\n",
"\n",
"######################################\n",
"# Sidebar Options\n",
"def sidebar_bg(side_bg):\n",
" side_bg_ext = 'png'\n",
" st.markdown(\n",
" f\"\"\"\n",
" <style>\n",
" [data-testid=\"stSidebar\"] > div:first-child {{\n",
" background: url(data:image/{side_bg_ext};base64,{base64.b64encode(open(side_bg, \"rb\").read()).decode()});\n",
" }}\n",
" </style>\n",
" \"\"\",\n",
" unsafe_allow_html=True,\n",
" )\n",
"\n",
"\n",
"side_bg = './images/sidebar2-dark.png'\n",
"sidebar_bg(side_bg)\n",
"\n",
"# sidebar logo\n",
"st.markdown(\n",
" \"\"\"\n",
" <style>\n",
" [data-testid=stSidebar] [data-testid=stImage]{\n",
" text-align: center;\n",
" display: block;\n",
" margin-left: auto;\n",
" margin-right: auto;\n",
" width: 100%;\n",
" }\n",
" </style>\n",
" \"\"\", unsafe_allow_html=True\n",
")\n",
"\n",
"with st.sidebar:\n",
" st.image(\"images/elastic_logo_transp_100.png\")\n",
"\n",
"######################################\n",
"# expander markdown\n",
"st.markdown(\n",
" '''\n",
" <style>\n",
" .streamlit-expanderHeader {\n",
" background-color: gray;\n",
" color: black; # Adjust this for expander header color\n",
" }\n",
" .streamlit-expanderContent {\n",
" background-color: white;\n",
" color: black; # Expander content color\n",
" }\n",
" </style>\n",
" ''',\n",
" unsafe_allow_html=True\n",
")\n",
"\n",
"######################################\n",
"\n",
"# Configure OpenAI client\n",
"openai.api_key = os.environ['OPENAI_API_KEY']\n",
"openai.api_base = os.environ['OPENAI_API_BASE']\n",
"openai.default_model = os.environ['OPENAI_API_ENGINE']\n",
"openai.verify_ssl_certs = False\n",
"client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)\n",
"\n",
"\n",
"# Initialize Elasticsearch and APM clients\n",
"# Configure APM and Elasticsearch clients\n",
"@st.cache_resource\n",
"def initElastic():\n",
" os.environ['ELASTIC_APM_SERVICE_NAME'] = \"genai_workshop_v2_lab_2-2\"\n",
" apmclient = elasticapm.Client()\n",
" elasticapm.instrument()\n",
"\n",
" if 'ELASTIC_CLOUD_ID' in os.environ:\n",
" es = Elasticsearch(\n",
" cloud_id=os.environ['ELASTIC_CLOUD_ID'],\n",
" api_key=(os.environ['ELASTIC_APIKEY_ID']),\n",
" request_timeout=30\n",
" )\n",
" else:\n",
" es = Elasticsearch(\n",
" os.environ['ELASTIC_URL'],\n",
" basic_auth=(os.environ['ELASTIC_USER'], os.environ['ELASTIC_PASSWORD']),\n",
" request_timeout=30\n",
" )\n",
"\n",
" if os.environ['ELASTIC_PROXY'] != \"True\":\n",
" openai.api_type = os.environ['OPENAI_API_TYPE']\n",
" openai.api_version = os.environ['OPENAI_API_VERSION']\n",
"\n",
" return apmclient, es\n",
"\n",
"\n",
"apmclient, es = initElastic()\n",
"\n",
"# Set our data index\n",
"index = os.environ['ELASTIC_INDEX_DOCS']\n",
"\n",
"###############################################################\n",
"# Similarity Cache functions\n",
"# move to env if time\n",
"cache_index = \"wikipedia-cache\"\n",
"\n",
"\n",
"def clear_es_cache(es):\n",
" print('clearing cache')\n",
" match_all_query = {\"query\": {\"match_all\": {}}}\n",
" clear_response = es.delete_by_query(index=cache_index, body=match_all_query)\n",
" return clear_response\n",
"\n",
"\n",
"@elasticapm.capture_span(\"cache_search\")\n",
"def cache_query(cache, prompt_text, similarity_threshold=0.5):\n",
" hit = cache.query(prompt_text=prompt_text, similarity_threshold=similarity_threshold)\n",
"\n",
" if hit:\n",
" st.sidebar.markdown('`Cache Match Found`')\n",
" else:\n",
" st.sidebar.markdown('`Cache Miss`')\n",
"\n",
" return hit\n",
"\n",
"\n",
"@elasticapm.capture_span(\"add_to_cache\")\n",
"def add_to_cache(cache, prompt, response):\n",
" st.sidebar.markdown('`Adding response to cache`')\n",
" print('adding to cache')\n",
" print(prompt)\n",
" print(response)\n",
" resp = cache.add(prompt=prompt, response=response)\n",
" st.markdown(resp)\n",
" return resp\n",
"\n",
"\n",
"# Init Elasticsearch Cache\n",
"# Only want to attempt to create the index on first run\n",
"cache = ElasticsearchLLMCache(es_client=es,\n",
" index_name=cache_index,\n",
" create_index=False # setting only because of Streamlit behavior\n",
" )\n",
"st.sidebar.markdown('`creating Elasticsearch Cache`')\n",
"\n",
"if \"index_created\" not in st.session_state:\n",
"\n",
" st.sidebar.markdown('`running create_index`')\n",
" cache.create_index(768)\n",
"\n",
" # Set the flag so it doesn't run every time\n",
" st.session_state.index_created = True\n",
"else:\n",
" st.sidebar.markdown('`index already created, skipping`')\n",
"\n",
"\n",
"def calc_similarity(score, func_type='dot_product'):\n",
" if func_type == 'dot_product':\n",
" return (score + 1) / 2\n",
" elif func_type in ['cosine', 'l2_norm']:\n",
" # TODO\n",
" return score\n",
" else:\n",
" return score\n",
"\n",
"\n",
"###############################################################\n",
"\n",
"\n",
"def get_bm25_query(query_text, augment_method):\n",
" if augment_method == \"Full Text\":\n",
" return {\n",
" \"match\": {\n",
" \"text\": query_text\n",
" }\n",
" }\n",
" elif augment_method == \"Matching Chunk\":\n",
" return {\n",
" \"nested\": {\n",
" \"path\": \"passages\",\n",
" \"query\": {\n",
" \"bool\": {\n",
" \"must\": [\n",
" {\n",
" \"match\": {\n",
" \"passages.text\": query_text\n",
" }\n",
" }\n",
" ]\n",
" }\n",
" },\n",
" \"inner_hits\": {\n",
" \"_source\": False,\n",
" \"fields\": [\n",
" \"passages.text\"\n",
" ]\n",
" }\n",
"\n",
" }\n",
" }\n",
"\n",
"\n",
"# Run an Elasticsearch query using BM25 relevance scoring\n",
"@elasticapm.capture_span(\"bm25_search\")\n",
"def search_bm25(query_text,\n",
" es,\n",
" size=1,\n",
" augment_method=\"Full Text\",\n",
" use_hybrid=False # always false - use semantic opt for hybrid\n",
" ):\n",
" fields = [\n",
" \"text\",\n",
" \"title\",\n",
" ]\n",
"\n",
" resp = es.search(index=index,\n",
" query=get_bm25_query(query_text, augment_method),\n",
" fields=fields,\n",
" size=size,\n",
" source=False)\n",
" # print(resp)\n",
" body = resp\n",
" url = 'nothing'\n",
"\n",
" return body, url\n",
"\n",
"\n",
"@elasticapm.capture_span(\"knn_search\")\n",
"def search_knn(query_text,\n",
" es,\n",
" size=1,\n",
" augment_method=\"Full Text\",\n",
" use_hybrid=False\n",
" ):\n",
" fields = [\n",
" \"title\",\n",
" \"text\"\n",
" ]\n",
"\n",
" knn = {\n",
" \"inner_hits\": {\n",
" \"_source\": False,\n",
" \"fields\": [\n",
" \"passages.text\"\n",
" ]\n",
" },\n",
" \"field\": \"passages.embeddings\",\n",
" \"k\": size,\n",
" \"num_candidates\": 100,\n",
" \"query_vector_builder\": {\n",
" \"text_embedding\": {\n",
" \"model_id\": \"sentence-transformers__all-distilroberta-v1\",\n",
" \"model_text\": query_text\n",
" }\n",
" }\n",
" }\n",
"\n",
" rank = {\"rrf\": {}} if use_hybrid else None\n",
"\n",
" # need to get the bm25 query if we are using hybrid\n",
" # query = get_bm25_query(query_text, augment_method) if use_hybrid else None\n",
" if use_hybrid:\n",
" print('using hybrid with augment method %s' % augment_method)\n",
" query = get_bm25_query(query_text, augment_method)\n",
" print(query)\n",
" if augment_method == \"Matching Chunk\":\n",
" del query['nested']['inner_hits']\n",
" else:\n",
" print('not using hybrid')\n",
" query = None\n",
"\n",
" print(query)\n",
" print(knn)\n",
"\n",
" resp = es.search(index=index,\n",
" knn=knn,\n",
" query=query,\n",
" fields=fields,\n",
" size=size,\n",
" rank=rank,\n",
" source=False)\n",
"\n",
" # body = resp['hits']['hits'][0]['fields']['body_content'][0]\n",
" # url = resp['hits']['hits'][0]['fields']['url'][0]\n",
"\n",
" return resp, None\n",
"\n",
"\n",
"def truncate_text(text, max_tokens):\n",
" tokens = text.split()\n",
" if len(tokens) <= max_tokens:\n",
" return text\n",
"\n",
" return ' '.join(tokens[:max_tokens])\n",
"\n",
"\n",
"def build_text_obj(resp, aug_method):\n",
"\n",
" tobj = {}\n",
"\n",
" for hit in resp['hits']['hits']:\n",
" # tobj[hit['fields']['title'][0]] = []\n",
" title = hit['fields']['title'][0]\n",
" tobj.setdefault(title, [])\n",
"\n",
" if aug_method == \"Matching Chunk\":\n",
" print('hit')\n",
" print(hit)\n",
" # tobj['passages'] = []\n",
" for ihit in hit['inner_hits']['passages']['hits']['hits']:\n",
" tobj[title].append(\n",
" {'passage': ihit['fields']['passages'][0]['text'][0],\n",
" '_score': ihit['_score']}\n",
" )\n",
" elif aug_method == \"Full Text\":\n",
" tobj[title].append(\n",
" hit['fields']\n",
" )\n",
"\n",
" return tobj\n",
"\n",
"\n",
"def generate_response(query,\n",
" es,\n",
" search_method,\n",
" custom_prompt,\n",
" negative_response,\n",
" show_prompt, size=1,\n",
" augment_method=\"Full Text\",\n",
" use_hybrid=False,\n",
" show_es_response=True,\n",
" show_es_augment=True,\n",
" ):\n",
" \"\"\"\n",
" Generates a response from ChatGPT based on the given query and Search Method.\n",
" Formats the prompt, sends it to ChatGPT, and displays the results.\n",
" \"\"\"\n",
"\n",
" # Perform the search based on the specified method\n",
" search_functions = {\n",
" 'bm25': {'method': search_bm25, 'display': 'Lexical Search'},\n",
" 'knn': {'method': search_knn, 'display': 'Semantic Search'}\n",
" }\n",
" search_func = search_functions.get(search_method)['method']\n",
" if not search_func:\n",
" raise ValueError(f\"Invalid search method: {search_method}\")\n",
"\n",
" # Perform the search and format the docs\n",
" response, url = search_func(query, es, size, augment_method, use_hybrid)\n",
" es_time = time.time()\n",
" augment_text = build_text_obj(response, augment_method)\n",
"\n",
" res_col1, res_col2 = st.columns(2)\n",
" # Display the search results from ES\n",
" with res_col2:\n",
" st.header(':rainbow[Elasticsearch Response]')\n",
" st.subheader(':orange[Search Settings]')\n",
" st.write(':gray[Search Method:] :blue[%s]' % search_functions.get(search_method)['display'])\n",
" st.write(':gray[Size Setting:] :blue[%s]' % size)\n",
" st.write(':gray[Augment Setting:] :blue[%s]' % augment_method)\n",
" st.write(':gray[Using Hybrid:] :blue[%s]' % (\n",
" 'Not Applicable with Lexical' if search_method == 'bm25' else use_hybrid))\n",
"\n",
" st.subheader(':green[Augment Chunk(s) from Elasticsearch]')\n",
" if show_es_augment:\n",
" st.json(dict(augment_text))\n",
" else:\n",
" st.write(':blue[Show Augment Disabled]')\n",
"\n",
" st.subheader(':violet[Elasticsearch Response]')\n",
" if show_es_response:\n",
" st.json(dict(response))\n",
" else:\n",
" st.write(':blue[Response Received]')\n",
"\n",
" formatted_prompt = custom_prompt.replace(\"$query\", query).replace(\"$response\", str(augment_text)).replace(\n",
" \"$negResponse\", negative_response)\n",
"\n",
" with res_col1:\n",
" st.header(':orange[GenAI Response]')\n",
"\n",
" chat_response = chat_gpt(formatted_prompt, system_prompt=\"You are a helpful assistant.\")\n",
" st.markdown(chat_response)\n",
"\n",
" # Display results\n",
" if show_prompt:\n",
" st.text(\"Full prompt sent to ChatGPT:\")\n",
" st.text(wrap_text(formatted_prompt, 70))\n",
"\n",
" if negative_response not in chat_response:\n",
" pass\n",
" else:\n",
" chat_response = None\n",
"\n",
" return es_time, chat_response\n",
"\n",
"\n",
"def chat_gpt(user_prompt, system_prompt):\n",
" \"\"\"\n",
" Generates a response from ChatGPT based on the given user and system prompts.\n",
" \"\"\"\n",
" max_tokens = 1024\n",
" max_context_tokens = 4000\n",
" safety_margin = 5\n",
"\n",
" # Truncate the prompt content to fit within the model's context length\n",
" truncated_prompt = truncate_text(user_prompt, max_context_tokens - max_tokens - safety_margin)\n",
"\n",
" # Prepare the messages for the ChatGPT API\n",
" messages = [{\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": truncated_prompt}]\n",
"\n",
" # Make the OpenAI API call\n",
" # response = openai.ChatCompletion.create(model=openai.default_model, temperature=0, messages=messages,\n",
" # # TODO stream=True\n",
" # )\n",
" response = client.chat.completions.create(model=openai.default_model,\n",
" temperature=0,\n",
" messages=messages,\n",
" # TODO stream=True\n",
" )\n",
"\n",
" # Add APM metadata and return the response content\n",
" elasticapm.set_custom_context({'model': openai.default_model, 'prompt': user_prompt})\n",
" # return response[\"choices\"][0][\"message\"][\"content\"]\n",
" return response.choices[0].message.content\n",
"\n",
"\n",
"# Main chat form\n",
"st.title(\"Wikipedia RAG Demo Platform\")\n",
"\n",
"# Define the default prompt and negative response\n",
"default_prompt_intro = \"Answer this question:\"\n",
"default_response_instructions = (\"using only the information from the wikipedia documents included and nothing \"\n",
" \"else.\\nwikipedia_docs: $response\\n\")\n",
"default_negative_response = (\"If the answer is not provided in the included documentation. You are to ONLY reply with \"\n",
" \"'I'm unable to answer the question based on the information I have from wikipedia' and \"\n",
" \"nothing else.\")\n",
"\n",
"with st.form(\"chat_form\"):\n",
" query = st.text_input(\"Ask the Elastic Documentation a question:\",\n",
" placeholder='How do I secure my elastic cluster?')\n",
"\n",
" opt_col1, opt_col2 = st.columns(2)\n",
" with opt_col1:\n",
" with st.expander(\"Customize Prompt Template\"):\n",
" prompt_intro = st.text_area(\"Introduction/context of the prompt:\", value=default_prompt_intro)\n",
" prompt_query_placeholder = st.text_area(\"Placeholder for the user's query:\", value=\"$query\")\n",
" prompt_response_placeholder = st.text_area(\"Placeholder for the Elasticsearch response:\",\n",
" value=default_response_instructions)\n",
" prompt_negative_response = st.text_area(\"Negative response placeholder:\", value=default_negative_response)\n",
" prompt_closing = st.text_area(\"Closing remarks of the prompt:\",\n",
" value=\"Format the answer in complete markdown code format.\")\n",
"\n",
" combined_prompt = f\"{prompt_intro}\\n{prompt_query_placeholder}\\n{prompt_response_placeholder}\\n{prompt_negative_response}\\n{prompt_closing}\"\n",
" st.text_area(\"Preview of your custom prompt:\", value=combined_prompt, disabled=True)\n",
"\n",
" with opt_col2:\n",
" with st.expander(\"Retrieval Search and Display Options\"):\n",
" st.subheader(\"Retrieval Options\")\n",
" ret_1, ret_2 = st.columns(2)\n",
" with ret_1:\n",
" search_method = st.radio(\"Search Method\", (\"Semantic Search\", \"Lexical Search\"))\n",
" augment_method = st.radio(\"Augment Method\", (\"Full Text\", \"Matching Chunk\"))\n",
" with ret_2:\n",
" # TODO this should update the title based on the augment_method\n",
" doc_count_title = \"Number of docs or chunks to Augment with\" if augment_method == \"Full Text\" else \"Number of Matching Chunks to Retrieve\"\n",
" doc_count = st.slider(doc_count_title, min_value=1, max_value=5, value=1)\n",
"\n",
" use_hybrid = st.checkbox('Use Hybrid Search')\n",
"\n",
" st.divider()\n",
"\n",
" st.subheader(\"Display Options\")\n",
" show_es_augment = st.checkbox('Show Elasticsearch Augment Text', value=True)\n",
" show_es_response = st.checkbox('Show Elasticsearch Response', value=True)\n",
" show_full_prompt = st.checkbox('Show Full Prompt Sent to LLM')\n",
"\n",
" st.divider()\n",
"\n",
" st.subheader(\"Caching Options\")\n",
" cache_1, cache_2 = st.columns(2)\n",
" with cache_1:\n",
" use_cache = st.checkbox('Use Similarity Cache')\n",
" # Slider for adjusting similarity threshold\n",
" similarity_threshold_selection = st.slider(\n",
" \"Select Similarity Threshold (dot_product - Higher Similarity means closer)\",\n",
" min_value=0.0, max_value=2.0,\n",
" value=0.5, step=0.01)\n",
"\n",
" with cache_2:\n",
" clear_cache_butt = st.form_submit_button(':red[Clear Similarity Cache]')\n",
"\n",
" col1, col2 = st.columns(2)\n",
" with col1:\n",
" answer_button = st.form_submit_button(\"Find my answer!\")\n",
"\n",
"# Clear Cache Button\n",
"if clear_cache_butt:\n",
" st.session_state.clear_cache_clicked = True\n",
"\n",
"# Confirmation step\n",
"if st.session_state.get(\"clear_cache_clicked\", False):\n",
" apmclient.begin_transaction(\"clear_cache\")\n",
" elasticapm.label(action=\"clear_cache\")\n",
"\n",
" # Start timing\n",
" start_time = time.time()\n",
"\n",
" if st.button(\":red[Confirm Clear Cache]\"):\n",
" print('clear cache clicked')\n",
" # TODO if index doesn't exist, catch exception then create it\n",
" response = clear_es_cache(es)\n",
" st.success(\"Cache cleared successfully!\", icon=\"🤯\")\n",
" st.session_state.clear_cache_clicked = False # Reset the state\n",
"\n",
" apmclient.end_transaction(\"clear_cache\", \"success\")\n",
"\n",
"if answer_button:\n",
" search_method = \"knn\" if search_method == \"Semantic Search\" else \"bm25\"\n",
"\n",
" apmclient.begin_transaction(\"query\")\n",
" elasticapm.label(search_method=search_method)\n",
" elasticapm.label(query=query)\n",
"\n",
" # Start timing\n",
" start_time = time.time()\n",
"\n",
" if use_cache:\n",
" # check the llm cache first\n",
" st.sidebar.markdown('`Checking ES Cache`')\n",
" cache_check = cache_query(cache,\n",
" prompt_text=query,\n",
" similarity_threshold=similarity_threshold_selection\n",
" )\n",
" st.markdown(cache_check)\n",
" else:\n",
" cache_check = None\n",
" st.sidebar.markdown('`Skipping ES Cache`')\n",
"\n",
" try:\n",
"\n",
" if cache_check:\n",
" st.sidebar.markdown('`cache match, using cached results`')\n",
" st.subheader('Response from Cache')\n",
" answer_dict = json.loads(cache_check['response'][0])\n",
"\n",
" # Display response from LLM\n",
" st.header('LLM Response')\n",
" st.subheader(f\":violet[{answer_dict['recommended_movie']}]\")\n",
" # render_markdown_with_background(answer_dict['llm_answer'])\n",
" st.markdown(answer_dict['llm_answer'])\n",
"\n",
" # if answer_dict['review_id']:\n",
" # s_score = calc_similarity(query_check['_score'], func_type='dot_product')\n",
" # st.code(f\"Review ID used: {answer_dict['review_id']} | Similarity Value: {s_score:.5f}\")\n",
"\n",
" es_time = time.time()\n",
" llmAnswer = None # no need to recache the answer\n",
"\n",
" else:\n",
" # Use combined_prompt and show_full_prompt as arguments\n",
" es_time, llmAnswer = generate_response(query,\n",
" es,\n",
" search_method,\n",
" combined_prompt,\n",
" prompt_negative_response,\n",
" show_full_prompt,\n",
" doc_count,\n",
" augment_method,\n",
" use_hybrid,\n",
" show_es_response,\n",
" show_es_augment,\n",
" )\n",
" apmclient.end_transaction(\"query\", \"success\")\n",
"\n",
" if use_cache and llmAnswer:\n",
" st.sidebar.markdown('`adding prompt and response to cache`')\n",
" add_to_cache(cache, query, llmAnswer)\n",
"\n",
" # End timing and print the elapsed time\n",
" elapsed_time = time.time() - start_time\n",
" es_elapsed_time = es_time - start_time\n",
"\n",
" ct1, ct2 = st.columns(2)\n",
" with ct1:\n",
" st.subheader(\"GenAI Time taken: :red[%.2f seconds]\" % elapsed_time)\n",
"\n",
" with ct2:\n",
" st.subheader(\"ES Query Time taken: :green[%.2f seconds]\" % es_elapsed_time)\n",
"\n",
" except Exception as e:\n",
" st.error(f\"An error occurred: {str(e)}\")\n",
" apmclient.end_transaction(\"query\", \"failure\")\n"
]
},
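{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optional: Exercise the Similarity Cache Directly\n",
"The app caches LLM answers in the `wikipedia-cache` index via `elasticsearch_llm_cache`. The cell below is a minimal sketch of that round-trip, assuming the same environment variables as the app and using only the library calls already shown in `app.py` (`add` and `query`). Run it after the app has created the cache index."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the cache round-trip used by app.py.\n",
"# Assumes ELASTIC_CLOUD_ID / ELASTIC_APIKEY_ID are set (as in the app) and that\n",
"# the 'wikipedia-cache' index already exists (app.py creates it on first run).\n",
"import os\n",
"from elasticsearch import Elasticsearch\n",
"from elasticsearch_llm_cache.elasticsearch_llm_cache import ElasticsearchLLMCache\n",
"\n",
"es = Elasticsearch(cloud_id=os.environ['ELASTIC_CLOUD_ID'],\n",
"                   api_key=os.environ['ELASTIC_APIKEY_ID'],\n",
"                   request_timeout=30)\n",
"\n",
"demo_cache = ElasticsearchLLMCache(es_client=es,\n",
"                                   index_name='wikipedia-cache',\n",
"                                   create_index=False)\n",
"\n",
"# Store a prompt/response pair, then query again with similar wording\n",
"demo_cache.add(prompt='What is the capital of France?', response='Paris')\n",
"hit = demo_cache.query(prompt_text='France capital city?', similarity_threshold=0.5)\n",
"print(hit if hit else 'cache miss')"
]
},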
{
"cell_type": "markdown",
"metadata": {
"id": "Wu0KfS0ESf6e"
},
"source": [
"## Streamlit\n",
"Running this cell will start local tunnel and generate a random URL\n",
"\n",
"Copy the IP address on the first line then open the generated URL and paste it in the input box \"Endpoint IP\"\n",
"\n",
"This will then start the Streamlit app"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cHIHFID3NBXa"
},
"outputs": [],
"source": [
"!streamlit run app.py &>/content/logs.txt & npx localtunnel --port 8501 & curl ipv4.icanhazip.com"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}