notebooks/genai_colab_lab_4.ipynb (1,140 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "nFbQGw2POViM" }, "source": [ "# Lab 4 - RAG" ] }, { "cell_type": "markdown", "metadata": { "id": "iSVyZRkvmqyc" }, "source": [ "## Setup Environment\n", "The following code loads the environment variables, images for the RAG App, and libraries required to run this notebook.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BwWfBNdUmqyd" }, "outputs": [], "source": [ "FILE=\"GenAI Lab 4\"\n", "\n", "! pip install -qqq git+https://github.com/elastic/notebook-workshop-loader.git@main\n", "from notebookworkshoploader import loader\n", "import os\n", "from dotenv import load_dotenv\n", "\n", "if os.path.isfile(\"../env\"):\n", " load_dotenv(\"../env\", override=True)\n", " print('Successfully loaded environment variables from local env file')\n", "else:\n", " loader.load_remote_env(file=FILE, env_url=\"https://notebook-workshop-api-voldmqr2bq-uc.a.run.app\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nufMOSMbZ5Tr" }, "outputs": [], "source": [ "! git clone \"https://github.com/elastic/genai-workshop-colab.git\"\n", "! cd genai-workshop-colab; git checkout main; cd ..; cp -r ./genai-workshop-colab/notebooks/images images; cp -r ./genai-workshop-colab/notebooks/.streamlit .streamlit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ln-8SRvAI-jS" }, "outputs": [], "source": [ "! pip install -qqq tiktoken==0.5.2 cohere==4.38 openai==1.3.9\n", "! pip install -qqq streamlit==1.30.0 elasticsearch==8.12.0 elastic-apm==6.20.0 inquirer==3.2.1 python-dotenv==1.0.1\n", "! pip install -qqq elasticsearch-llm-cache==0.9.5\n", "! npm install localtunnel --loglevel=error" ] }, { "cell_type": "markdown", "metadata": { "id": "E0uQujqZclf0" }, "source": [ "## <font color=Green>Labs</font>\n" ] }, { "cell_type": "markdown", "metadata": { "id": "IvZYvYkE62Df" }, "source": [ "### <font color=Orange>Lab 4.1 - Gathering Semantic documents from Elasticsearch</font>\n", "This first exercise will allow us to see an example of returing semantically matching documents from Elasticsearch.\n", "\n", "It is not too important to understand all the Elasticsearch DSL syntax at this stage.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "DsCwwEc95qv8" }, "source": [ "#### Run the code block below to set up the query function\n", "---\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "l7lu2VBg6vMN" }, "outputs": [], "source": [ "import openai\n", "from elasticsearch import Elasticsearch\n", "import time\n", "import json\n", "import textwrap\n", "\n", "\n", "index = os.environ['ELASTIC_INDEX_DOCS_W']\n", "\n", "# Create Elasticsearch Connection\n", "es = Elasticsearch(\n", " cloud_id=os.environ['ELASTIC_CLOUD_ID_W'],\n", " api_key=(os.environ['ELASTIC_APIKEY_ID_W']),\n", " request_timeout=30\n", " )\n", "\n", "\n", "# Search Function\n", "def es_hybrid_search(question):\n", " query = {\n", " \"nested\": {\n", " \"path\": \"passages\",\n", " \"query\": {\n", " \"bool\": {\n", " \"must\": [\n", " {\n", " \"match\": {\n", " \"passages.text\": question\n", " }\n", " }\n", " ]\n", " }\n", " }\n", " }\n", " }\n", "\n", " knn = {\n", " \"inner_hits\": {\n", " \"_source\": False,\n", " \"fields\": [\n", " \"passages.text\"\n", " ]\n", " },\n", " \"field\": \"passages.embeddings\",\n", " \"k\": 5,\n", " \"num_candidates\": 100,\n", " \"query_vector_builder\": {\n", " \"text_embedding\": {\n", " \"model_id\": \"sentence-transformers__all-distilroberta-v1\",\n", " \"model_text\": question\n", " }\n", " }\n", " }\n", "\n", " rank = {\n", " \"rrf\": {}\n", " }\n", "\n", " fields = [\n", " \"title\",\n", " \"text\"\n", " ]\n", "\n", " size = 5\n", "\n", " resp = es.search(index=index,\n", " #query=query,\n", " knn=knn,\n", " fields=fields,\n", " size=size,\n", " #rank=rank,\n", " source=False\n", " )\n", "\n", " title_text = []\n", " for doc in resp['hits']['hits']:\n", " title_text.append( { 'title' : doc['fields']['title'][0],\n", " 'passage' : doc['inner_hits']['passages']['hits']['hits'][0]['fields']['passages'][0]['text'][0] }\n", " )\n", "\n", " return title_text" ] }, { "cell_type": "markdown", "metadata": { "id": "eKBumt6W68wE" }, "source": [ "#### Example Semantic Search With Elastic\n", "Querying semantic search using the [sentence-transformers/all-distilroberta-v1](https://huggingface.co/sentence-transformers/all-distilroberta-v1) model." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "h4hlknOP-Tba" }, "outputs": [], "source": [ "user_question = \"Who is Batman?\"" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "qpHyxzev4WZm", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "dd3c5a07-806e-4bce-b2be-89648335d6c2" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Wikipedia titles returned:\n", "\n", "0 - David Cain (character)\n", "1 - Batman\n", "2 - Bruce Wayne (Dark Knight trilogy)\n", "3 - Batman Beyond\n", "4 - We Are Robin\n" ] } ], "source": [ "es_augment_docs = es_hybrid_search(user_question)\n", "\n", "print('Wikipedia titles returned:\\n')\n", "for hit, wiki in enumerate(es_augment_docs):\n", " print(f\"{hit} - {wiki['title'] }\" )" ] }, { "cell_type": "markdown", "metadata": { "id": "dPVcfU_26rGI" }, "source": [ "### <font color=Orange>Lab 4.2 - Sending Elasticsearch docs with a prompt for a RAG response</font>" ] }, { "cell_type": "markdown", "metadata": { "id": "UZRE3N0q61L3" }, "source": [ "#### Run the code below to set up the LLM Connection" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "aWeL5ANw65ND" }, "outputs": [], "source": [ "import openai\n", "from openai import OpenAI\n", "import textwrap\n", "\n", "\n", "# Configure OpenAI client\n", "openai.api_key = os.environ['OPENAI_API_KEY']\n", "openai.api_base = os.environ['OPENAI_API_BASE']\n", "openai.default_model = os.environ['OPENAI_API_ENGINE']\n", "openai.verify_ssl_certs = False\n", "client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)\n", "\n", "if os.environ['ELASTIC_PROXY'] != \"True\":\n", " openai.api_type = os.environ['OPENAI_API_TYPE']\n", " openai.api_version = os.environ['OPENAI_API_VERSION']\n", "\n", "\n", "# Text wrapper for colab readibility\n", "def wrap_text(text):\n", " wrapped_text = textwrap.wrap(text, 70)\n", " return '\\n'.join(wrapped_text)\n", "\n", "\n", "# Function to connect with LLM\n", "def chat_gpt(client, question, passages):\n", "\n", " system_prompt=\"You are a helpful assistant who answers questions from provided Wikipedia articles.\"\n", " user_prompt = f'''Answer the followng question: {question}\n", " using only the wikipedia `passages` provided.\n", " If the answer is not provided in the `passages` respond ONLY with:\n", " \"I am unable to answer the user's question from the provided passage\" and nothing else.\n", "\n", " passages: {passages}\n", "\n", " AI response:\n", " '''\n", "\n", " # Prepare the messages for the ChatGPT API\n", " messages = [{\"role\": \"system\", \"content\": system_prompt},\n", " {\"role\": \"user\", \"content\": user_prompt}]\n", "\n", " response = client.chat.completions.create(model=openai.default_model,\n", " temperature=0.2,\n", " messages=messages,\n", " )\n", " return response\n", "# return response.choices[0].message.content" ] }, { "cell_type": "markdown", "metadata": { "id": "pQ4ZijSv65tQ" }, "source": [ "#### Pass the full prompt and wiki passages to LLM" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "MR-XrChD6-E0", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7ace879d-8403-4625-8fb4-3b88fc07d42e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "User Question: \n", "Who is Batman?\n", "\n", "AI response:\n", "Batman is a superhero appearing in American comic books published by\n", "DC Comics. The character was created by artist Bob Kane and writer\n", "Bill Finger, and debuted in the 27th issue of the comic book Detective\n", "Comics on March 30, 1939. In the DC Universe continuity, Batman is the\n", "alias of Bruce Wayne, a wealthy American playboy, philanthropist, and\n", "industrialist who resides in Gotham City. Batman's origin story\n", "features him swearing vengeance against criminals after witnessing the\n", "murder of his parents Thomas and Martha as a child, a vendetta\n", "tempered with the ideal of justice. He trains himself physically and\n", "intellectually, crafts a bat-inspired persona, and monitors the Gotham\n", "streets at night. Kane, Finger, and other creators accompanied Batman\n", "with supporting characters, including his sidekicks Robin and Batgirl;\n", "allies Alfred Pennyworth, James Gordon, and Catwoman; and foes such as\n", "the Penguin, the Riddler, Two-Face, and his archenemy, the Joker.\n" ] } ], "source": [ "ai = chat_gpt(client, user_question, es_augment_docs)\n", "print(f\"User Question: \\n{user_question}\\n\")\n", "print(\"AI response:\")\n", "print(wrap_text(ai.choices[0].message.content))" ] }, { "cell_type": "markdown", "metadata": { "id": "t7RmurdZNPg-" }, "source": [ "### <font color=Orange>Lab 4.3 - Full RAG Application with UI</font>\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Lk7r-2ZAJJkt" }, "source": [ "#### Setup\n", "Running this cell will write a file named `app.py` into the Colab environment.\n", "\n", "This is the code needed to run the RAG application" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "Wc02OlkpOSd2", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "1faf8437-b41d-45e1-865e-462f22c00636" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Writing app.py\n" ] } ], "source": [ "%%writefile app.py\n", "\n", "import os\n", "import streamlit as st\n", "import openai\n", "from openai import OpenAI\n", "from elasticsearch import Elasticsearch\n", "import elasticapm\n", "import base64\n", "from elasticsearch_llm_cache.elasticsearch_llm_cache import ElasticsearchLLMCache\n", "import time\n", "import json\n", "import textwrap\n", "\n", "######################################\n", "# Streamlit Configuration\n", "st.set_page_config(layout=\"wide\")\n", "\n", "\n", "# wrap text when printing, because colab scrolls output to the right too much\n", "def wrap_text(text, width):\n", " wrapped_text = textwrap.wrap(text, width)\n", " return '\\n'.join(wrapped_text)\n", "\n", "\n", "@st.cache_data()\n", "def get_base64(bin_file):\n", " with open(bin_file, 'rb') as f:\n", " data = f.read()\n", " return base64.b64encode(data).decode()\n", "\n", "\n", "def set_background(png_file):\n", " bin_str = get_base64(png_file)\n", " page_bg_img = '''\n", " <style>\n", " .stApp {\n", " background-image: url(\"data:image/png;base64,%s\");\n", " background-size: cover;\n", " }\n", " </style>\n", " ''' % bin_str\n", " st.markdown(page_bg_img, unsafe_allow_html=True)\n", " return\n", "\n", "\n", "set_background('images/background-dark2.jpeg')\n", "\n", "\n", "######################################\n", "\n", "######################################\n", "# Sidebar Options\n", "def sidebar_bg(side_bg):\n", " side_bg_ext = 'png'\n", " st.markdown(\n", " f\"\"\"\n", " <style>\n", " [data-testid=\"stSidebar\"] > div:first-child {{\n", " background: url(data:image/{side_bg_ext};base64,{base64.b64encode(open(side_bg, \"rb\").read()).decode()});\n", " }}\n", " </style>\n", " \"\"\",\n", " unsafe_allow_html=True,\n", " )\n", "\n", "\n", "side_bg = './images/sidebar2-dark.png'\n", "sidebar_bg(side_bg)\n", "\n", "# sidebar logo\n", "st.markdown(\n", " \"\"\"\n", " <style>\n", " [data-testid=stSidebar] [data-testid=stImage]{\n", " text-align: center;\n", " display: block;\n", " margin-left: auto;\n", " margin-right: auto;\n", " width: 100%;\n", " }\n", " </style>\n", " \"\"\", unsafe_allow_html=True\n", ")\n", "\n", "with st.sidebar:\n", " st.image(\"images/elastic_logo_transp_100.png\")\n", "\n", "######################################\n", "# expander markdown\n", "st.markdown(\n", " '''\n", " <style>\n", " .streamlit-expanderHeader {\n", " background-color: gray;\n", " color: black; # Adjust this for expander header color\n", " }\n", " .streamlit-expanderContent {\n", " background-color: white;\n", " color: black; # Expander content color\n", " }\n", " </style>\n", " ''',\n", " unsafe_allow_html=True\n", ")\n", "\n", "######################################\n", "\n", "# Configure OpenAI client\n", "openai.api_key = os.environ['OPENAI_API_KEY']\n", "openai.api_base = os.environ['OPENAI_API_BASE']\n", "openai.default_model = os.environ['OPENAI_API_ENGINE']\n", "openai.verify_ssl_certs = False\n", "client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)\n", "\n", "\n", "# Initialize Elasticsearch and APM clients\n", "# Configure APM and Elasticsearch clients\n", "@st.cache_resource\n", "def initElastic():\n", " os.environ['ELASTIC_APM_SERVICE_NAME'] = \"genai_workshop_v2_lab_2-2\"\n", " apmclient = elasticapm.Client()\n", " elasticapm.instrument()\n", "\n", " if 'ELASTIC_CLOUD_ID' in os.environ:\n", " es = Elasticsearch(\n", " cloud_id=os.environ['ELASTIC_CLOUD_ID'],\n", " api_key=(os.environ['ELASTIC_APIKEY_ID']),\n", " request_timeout=30\n", " )\n", " else:\n", " es = Elasticsearch(\n", " os.environ['ELASTIC_URL'],\n", " basic_auth=(os.environ['ELASTIC_USER'], os.environ['ELASTIC_PASSWORD']),\n", " request_timeout=30\n", " )\n", "\n", " if os.environ['ELASTIC_PROXY'] != \"True\":\n", " openai.api_type = os.environ['OPENAI_API_TYPE']\n", " openai.api_version = os.environ['OPENAI_API_VERSION']\n", "\n", " return apmclient, es\n", "\n", "\n", "apmclient, es = initElastic()\n", "\n", "# Set our data index\n", "index = os.environ['ELASTIC_INDEX_DOCS']\n", "\n", "###############################################################\n", "# Similarity Cache functions\n", "# move to env if time\n", "cache_index = \"wikipedia-cache\"\n", "\n", "\n", "def clear_es_cache(es):\n", " print('clearing cache')\n", " match_all_query = {\"query\": {\"match_all\": {}}}\n", " clear_response = es.delete_by_query(index=cache_index, body=match_all_query)\n", " return clear_response\n", "\n", "\n", "@elasticapm.capture_span(\"cache_search\")\n", "def cache_query(cache, prompt_text, similarity_threshold=0.5):\n", " hit = cache.query(prompt_text=prompt_text, similarity_threshold=similarity_threshold)\n", "\n", " if hit:\n", " st.sidebar.markdown('`Cache Match Found`')\n", " else:\n", " st.sidebar.markdown('`Cache Miss`')\n", "\n", " return hit\n", "\n", "\n", "@elasticapm.capture_span(\"add_to_cache\")\n", "def add_to_cache(cache, prompt, response):\n", " st.sidebar.markdown('`Adding response to cache`')\n", " print('adding to cache')\n", " print(prompt)\n", " print(response)\n", " resp = cache.add(prompt=prompt, response=response)\n", " st.markdown(resp)\n", " return resp\n", "\n", "\n", "def init_elastic_cache():\n", " # Init Elasticsearch Cache\n", " # Only want to attempt to create the index on first run\n", " cache = ElasticsearchLLMCache(es_client=es,\n", " index_name=cache_index,\n", " create_index=False # setting only because of Streamlit behavior\n", " )\n", " st.sidebar.markdown('`creating Elasticsearch Cache`')\n", "\n", " if \"index_created\" not in st.session_state:\n", "\n", " st.sidebar.markdown('`running create_index`')\n", " cache.create_index(768)\n", "\n", " # Set the flag so it doesn't run every time\n", " st.session_state.index_created = True\n", " else:\n", " st.sidebar.markdown('`index already created, skipping`')\n", "\n", " return cache\n", "\n", "\n", "def calc_similarity(score, func_type='dot_product'):\n", " if func_type == 'dot_product':\n", " return (score + 1) / 2\n", " elif func_type == 'cosine':\n", " return (1 + score) / 2\n", " elif func_type == 'l2_norm':\n", " return 1 / (1 + score ^ 2)\n", " else:\n", " return score\n", "\n", "\n", "# l2_norm: sqrt((1 / _score) - 1)\n", "# cosine: (2 * _score) - 1\n", "# dot_product: (2 * _score) - 1\n", "# max_inner_product:\n", "# _score < 1: 1 - (1 / _score)\n", "# _score >= 1: _score - 1\n", "\n", "\n", "###############################################################\n", "\n", "\n", "def get_bm25_query(query_text, augment_method):\n", " if augment_method == \"Full Text\":\n", " return {\n", " \"match\": {\n", " \"text\": query_text\n", " }\n", " }\n", " elif augment_method == \"Matching Chunk\":\n", " return {\n", " \"nested\": {\n", " \"path\": \"passages\",\n", " \"query\": {\n", " \"bool\": {\n", " \"must\": [\n", " {\n", " \"match\": {\n", " \"passages.text\": query_text\n", " }\n", " }\n", " ]\n", " }\n", " },\n", " \"inner_hits\": {\n", " \"_source\": False,\n", " \"fields\": [\n", " \"passages.text\"\n", " ]\n", " }\n", "\n", " }\n", " }\n", "\n", "\n", "# Run an Elasticsearch query using BM25 relevance scoring\n", "@elasticapm.capture_span(\"bm25_search\")\n", "def search_bm25(query_text,\n", " es,\n", " size=1,\n", " augment_method=\"Full Text\",\n", " use_hybrid=False # always false - use semantic opt for hybrid\n", " ):\n", " fields = [\n", " \"text\",\n", " \"title\",\n", " ]\n", "\n", " resp = es.search(index=index,\n", " query=get_bm25_query(query_text, augment_method),\n", " fields=fields,\n", " size=size,\n", " source=False)\n", " # print(resp)\n", " body = resp\n", " url = 'nothing'\n", "\n", " return body, url\n", "\n", "\n", "@elasticapm.capture_span(\"knn_search\")\n", "def search_knn(query_text,\n", " es,\n", " size=1,\n", " augment_method=\"Full Text\",\n", " use_hybrid=False\n", " ):\n", " fields = [\n", " \"title\",\n", " \"text\"\n", " ]\n", "\n", " knn = {\n", " \"inner_hits\": {\n", " \"_source\": False,\n", " \"fields\": [\n", " \"passages.text\"\n", " ]\n", " },\n", " \"field\": \"passages.embeddings\",\n", " \"k\": size,\n", " \"num_candidates\": 100,\n", " \"query_vector_builder\": {\n", " \"text_embedding\": {\n", " \"model_id\": \"sentence-transformers__all-distilroberta-v1\",\n", " \"model_text\": query_text\n", " }\n", " }\n", " }\n", "\n", " rank = {\"rrf\": {}} if use_hybrid else None\n", "\n", " # need to get the bm25 query if we are using hybrid\n", " if use_hybrid:\n", " print('using hybrid with augment method %s' % augment_method)\n", " query = get_bm25_query(query_text, augment_method)\n", " print(query)\n", " if augment_method == \"Matching Chunk\":\n", " del query['nested']['inner_hits']\n", " else:\n", " print('not using hybrid')\n", " query = None\n", "\n", " print(query)\n", " print(knn)\n", "\n", " resp = es.search(index=index,\n", " knn=knn,\n", " query=query,\n", " fields=fields,\n", " size=size,\n", " rank=rank,\n", " source=False)\n", "\n", " return resp, None\n", "\n", "\n", "def truncate_text(text, max_tokens):\n", " tokens = text.split()\n", " if len(tokens) <= max_tokens:\n", " return text\n", "\n", " return ' '.join(tokens[:max_tokens])\n", "\n", "\n", "def build_text_obj(resp, aug_method):\n", "\n", " tobj = {}\n", "\n", " for hit in resp['hits']['hits']:\n", " # tobj[hit['fields']['title'][0]] = []\n", " title = hit['fields']['title'][0]\n", " tobj.setdefault(title, [])\n", "\n", " if aug_method == \"Matching Chunk\":\n", " print('hit')\n", " print(hit)\n", " # tobj['passages'] = []\n", " for ihit in hit['inner_hits']['passages']['hits']['hits']:\n", " tobj[title].append(\n", " {'passage': ihit['fields']['passages'][0]['text'][0],\n", " '_score': ihit['_score']}\n", " )\n", " elif aug_method == \"Full Text\":\n", " tobj[title].append(\n", " hit['fields']\n", " )\n", "\n", " return tobj\n", "\n", "\n", "def generate_response(query,\n", " es,\n", " search_method,\n", " custom_prompt,\n", " negative_response,\n", " show_prompt, size=1,\n", " augment_method=\"Full Text\",\n", " use_hybrid=False,\n", " show_es_response=True,\n", " show_es_augment=True,\n", " ):\n", "\n", " # Perform the search based on the specified method\n", " search_functions = {\n", " 'bm25': {'method': search_bm25, 'display': 'Lexical Search'},\n", " 'knn': {'method': search_knn, 'display': 'Semantic Search'}\n", " }\n", " search_func = search_functions.get(search_method)['method']\n", " if not search_func:\n", " raise ValueError(f\"Invalid search method: {search_method}\")\n", "\n", " # Perform the search and format the docs\n", " response, url = search_func(query, es, size, augment_method, use_hybrid)\n", " es_time = time.time()\n", " augment_text = build_text_obj(response, augment_method)\n", "\n", " res_col1, res_col2 = st.columns(2)\n", " # Display the search results from ES\n", " with res_col2:\n", " st.header(':rainbow[Elasticsearch Response]')\n", " st.subheader(':orange[Search Settings]')\n", " st.write(':gray[Search Method:] :blue[%s]' % search_functions.get(search_method)['display'])\n", " st.write(':gray[Size Setting:] :blue[%s]' % size)\n", " st.write(':gray[Augment Setting:] :blue[%s]' % augment_method)\n", " st.write(':gray[Using Hybrid:] :blue[%s]' % (\n", " 'Not Applicable with Lexical' if search_method == 'bm25' else use_hybrid))\n", "\n", " st.subheader(':green[Augment Chunk(s) from Elasticsearch]')\n", " if show_es_augment:\n", " st.json(dict(augment_text))\n", " else:\n", " st.write(':blue[Show Augment Disabled]')\n", "\n", " st.subheader(':violet[Elasticsearch Response]')\n", " if show_es_response:\n", " st.json(dict(response))\n", " else:\n", " st.write(':blue[Response Received]')\n", "\n", " formatted_prompt = custom_prompt.replace(\"$query\", query).replace(\"$response\", str(augment_text)).replace(\n", " \"$negResponse\", negative_response)\n", "\n", " with res_col1:\n", " st.header(':orange[GenAI Response]')\n", "\n", " chat_response = chat_gpt(formatted_prompt, system_prompt=\"You are a helpful assistant.\")\n", "\n", " # Display assistant response in chat message container\n", " with st.chat_message(\"assistant\"):\n", " message_placeholder = st.empty()\n", " full_response = \"\"\n", " for chunk in chat_response.split():\n", " full_response += chunk + \" \"\n", " time.sleep(0.02)\n", " # Add a blinking cursor to simulate typing\n", " message_placeholder.markdown(full_response + \"▌\")\n", " message_placeholder.markdown(full_response)\n", "\n", " # Display results\n", " if show_prompt:\n", " st.text(\"Full prompt sent to ChatGPT:\")\n", " st.text(wrap_text(formatted_prompt, 70))\n", "\n", " if negative_response not in chat_response:\n", " pass\n", " else:\n", " chat_response = None\n", "\n", " return es_time, chat_response\n", "\n", "\n", "def chat_gpt(user_prompt, system_prompt):\n", " \"\"\"\n", " Generates a response from ChatGPT based on the given user and system prompts.\n", " \"\"\"\n", " max_tokens = 1024\n", " max_context_tokens = 4000\n", " safety_margin = 5\n", "\n", " # Truncate the prompt content to fit within the model's context length\n", " truncated_prompt = truncate_text(user_prompt, max_context_tokens - max_tokens - safety_margin)\n", "\n", " # Prepare the messages for the ChatGPT API\n", " messages = [{\"role\": \"system\", \"content\": system_prompt},\n", " {\"role\": \"user\", \"content\": truncated_prompt}]\n", "\n", " # Add APM metadata and return the response content\n", " elasticapm.set_custom_context({'model': openai.default_model, 'prompt': user_prompt})\n", " # return response[\"choices\"][0][\"message\"][\"content\"]\n", "\n", " full_response = \"\"\n", " for response in client.chat.completions.create(\n", " model=openai.default_model,\n", " temperature=0,\n", " messages=messages,\n", " stream=True\n", " ):\n", " full_response += (response.choices[0].delta.content or \"\")\n", "\n", " return full_response\n", "\n", "\n", "# Main chat form\n", "st.title(\"Wikipedia RAG Demo Platform\")\n", "\n", "# Define the default prompt and negative response\n", "default_prompt_intro = \"Answer this question:\"\n", "default_response_instructions = (\"using only the information from the wikipedia documents included and nothing \"\n", " \"else.\\nwikipedia_docs: $response\\n\")\n", "default_negative_response = (\"If the answer is not provided in the included documentation. You are to ONLY reply with \"\n", " \"'I'm unable to answer the question based on the information I have from wikipedia' and \"\n", " \"nothing else.\")\n", "\n", "with st.form(\"chat_form\"):\n", " query = st.text_input(\"Ask a question in Wikipedia:\",\n", " placeholder='Sample Question: Who is Batman?')\n", "\n", " opt_col1, opt_col2 = st.columns(2)\n", " with opt_col1:\n", " with st.expander(\"Customize Prompt Template\"):\n", " prompt_intro = st.text_area(\"Introduction/context of the prompt:\", value=default_prompt_intro)\n", " prompt_query_placeholder = st.text_area(\"Placeholder for the user's query:\", value=\"$query\")\n", " prompt_response_placeholder = st.text_area(\"Placeholder for the Elasticsearch response:\",\n", " value=default_response_instructions)\n", " prompt_negative_response = st.text_area(\"Negative response placeholder:\", value=default_negative_response)\n", " prompt_closing = st.text_area(\"Closing remarks of the prompt:\",\n", " value=\"Format the answer in complete markdown code format.\")\n", "\n", " combined_prompt = f\"{prompt_intro}\\n{prompt_query_placeholder}\\n{prompt_response_placeholder}\\n{prompt_negative_response}\\n{prompt_closing}\"\n", " st.text_area(\"Preview of your custom prompt:\", value=combined_prompt, disabled=True)\n", "\n", " with opt_col2:\n", " with st.expander(\"Retrieval Search and Display Options\"):\n", " st.subheader(\"Retrieval Options\")\n", " ret_1, ret_2 = st.columns(2)\n", " with ret_1:\n", " search_method = st.radio(\"Search Method\", (\"Semantic Search\", \"Lexical Search\"))\n", " augment_method = st.radio(\"Augment Method\", (\"Full Text\", \"Matching Chunk\"))\n", " with ret_2:\n", " # TODO this should update the title based on the augment_method\n", " doc_count_title = \"Number of docs or chunks to Augment with\" if augment_method == \"Full Text\" else \"Number of Matching Chunks to Retrieve\"\n", " doc_count = st.slider(doc_count_title, min_value=1, max_value=5, value=1)\n", "\n", " use_hybrid = st.checkbox('Use Hybrid Search')\n", "\n", " st.divider()\n", "\n", " st.subheader(\"Display Options\")\n", " show_es_augment = st.checkbox('Show Elasticsearch Augment Text', value=True)\n", " show_es_response = st.checkbox('Show Elasticsearch Response', value=True)\n", " show_full_prompt = st.checkbox('Show Full Prompt Sent to LLM')\n", "\n", " st.divider()\n", "\n", " st.subheader(\"Caching Options\")\n", " cache_1, cache_2 = st.columns(2)\n", " with cache_1:\n", " use_cache = st.checkbox('Use Similarity Cache')\n", " # Slider for adjusting similarity threshold\n", " similarity_threshold_selection = st.slider(\n", " \"Select Similarity Threshold (dot_product - Higher Similarity means closer)\",\n", " min_value=0.0, max_value=2.0,\n", " value=0.5, step=0.01)\n", "\n", " with cache_2:\n", " clear_cache_butt = st.form_submit_button(':red[Clear Similarity Cache]')\n", "\n", " col1, col2 = st.columns(2)\n", " with col1:\n", " answer_button = st.form_submit_button(\"Find my answer!\")\n", "\n", "# Clear Cache Button\n", "if clear_cache_butt:\n", " st.session_state.clear_cache_clicked = True\n", "\n", "# Confirmation step\n", "if st.session_state.get(\"clear_cache_clicked\", False):\n", " apmclient.begin_transaction(\"clear_cache\")\n", " elasticapm.label(action=\"clear_cache\")\n", "\n", " # Start timing\n", " start_time = time.time()\n", "\n", " if st.button(\":red[Confirm Clear Cache]\"):\n", " print('clear cache clicked')\n", " # TODO if index doesn't exist, catch exception then create it\n", " response = clear_es_cache(es)\n", " st.success(\"Cache cleared successfully!\", icon=\"🤯\")\n", " st.session_state.clear_cache_clicked = False # Reset the state\n", "\n", " apmclient.end_transaction(\"clear_cache\", \"success\")\n", "\n", "if answer_button:\n", " search_method = \"knn\" if search_method == \"Semantic Search\" else \"bm25\"\n", "\n", " apmclient.begin_transaction(\"query\")\n", " elasticapm.label(search_method=search_method)\n", " elasticapm.label(query=query)\n", "\n", " # Start timing\n", " start_time = time.time()\n", "\n", " if query == \"\":\n", " st.error(\"Please enter a question in the Question Box.\")\n", " apmclient.end_transaction(\"query\", \"failure\")\n", "\n", " else:\n", " if use_cache:\n", " cache = init_elastic_cache()\n", "\n", " # check the llm cache first\n", " st.sidebar.markdown('`Checking ES Cache`')\n", " cache_check = cache_query(cache,\n", " prompt_text=query,\n", " similarity_threshold=similarity_threshold_selection\n", " )\n", " # st.markdown(cache_check)\n", " else:\n", " cache_check = None\n", " st.sidebar.markdown('`Skipping ES Cache`')\n", "\n", " try:\n", "\n", " if cache_check:\n", " es_time = time.time()\n", " st.sidebar.markdown('`cache match, using cached results`')\n", " st.subheader('Response from Cache')\n", " s_score = calc_similarity(cache_check['_score'], func_type='dot_product')\n", " st.code(f\"Similarity Value: {s_score:.5f}\")\n", "\n", " # Display response from LLM\n", " st.header('LLM Response')\n", " # st.markdown(cache_check['response'][0])\n", " with st.chat_message(\"assistant\"):\n", " message_placeholder = st.empty()\n", " full_response = \"\"\n", " for chunk in cache_check['response'][0].split():\n", " full_response += chunk + \" \"\n", " time.sleep(0.02)\n", " # Add a blinking cursor to simulate typing\n", " message_placeholder.markdown(full_response + \"▌\")\n", " message_placeholder.markdown(full_response)\n", "\n", " llmAnswer = None # no need to recache the answer\n", "\n", " else:\n", " # Use combined_prompt and show_full_prompt as arguments\n", " es_time, llmAnswer = generate_response(query,\n", " es,\n", " search_method,\n", " combined_prompt,\n", " prompt_negative_response,\n", " show_full_prompt,\n", " doc_count,\n", " augment_method,\n", " use_hybrid,\n", " show_es_response,\n", " show_es_augment,\n", " )\n", " apmclient.end_transaction(\"query\", \"success\")\n", "\n", " if use_cache and llmAnswer:\n", " if \"I'm unable to answer the question\" in llmAnswer:\n", " st.sidebar.markdown('`unable to answer, not adding to cache`')\n", " else:\n", " st.sidebar.markdown('`adding prompt and response to cache`')\n", " add_to_cache(cache, query, llmAnswer)\n", "\n", " # End timing and print the elapsed time\n", " elapsed_time = time.time() - start_time\n", " es_elapsed_time = es_time - start_time\n", "\n", " ct1, ct2 = st.columns(2)\n", " with ct1:\n", " st.subheader(\"GenAI Time taken: :red[%.2f seconds]\" % elapsed_time)\n", "\n", " with ct2:\n", " st.subheader(\"ES Query Time taken: :green[%.2f seconds]\" % es_elapsed_time)\n", "\n", " except Exception as e:\n", " st.error(f\"An error occurred: {str(e)}\")\n", " apmclient.end_transaction(\"query\", \"failure\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Wu0KfS0ESf6e" }, "source": [ "#### Run the RAG Application\n", "Running this cell will start local tunnel and generate a random URL\n", "\n", "1. Run this cell\n", "2. Copy the IP address on the first line\n", "3. Open the generated URL\n", "4. Paste the copied IP into the input box *Endpoint IP*\n", "\n", "This will then start the Rag Application" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cHIHFID3NBXa" }, "outputs": [], "source": [ "!streamlit run app.py &>/content/logs.txt & npx localtunnel --port 8501 & curl ipv4.icanhazip.com" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 0 }