components/frontend_streamlit/src/utils.py (148 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Streamlit app utils file
"""
import re
import ast
import streamlit as st
import logging
from config import API_BASE_URL, APP_BASE_PATH
from streamlit.runtime.scriptrunner import RerunData, RerunException
from streamlit.source_util import get_pages
from urllib.parse import urlparse
from api import validate_auth_token
# Matches ANSI escape sequences: ESC followed by a single-character code
# (excluding "["), or a CSI sequence: "ESC [" + parameter bytes (0-?) +
# intermediate bytes (space-/) + one final byte (@-~).
ansi_escape = re.compile(
    r"\x1B"
    r"(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])"
)
def http_navigate_to(url, query_params=None):
    """Navigate to ``url`` with a full page load (HTML meta refresh).

    A full page load loses all of ``st.session_state``, so the session's
    "auth_token" and "debug" values are forwarded as query parameters,
    followed by ``query_params`` (if any), followed by any query string that
    ``url`` already carries.

    Args:
      url: Destination URL. It may already contain a query string.
      query_params: Optional mapping of extra query parameters to add.
    """
    import html
    from urllib.parse import urlencode

    query_params_from_session = ["auth_token", "debug"]
    params = [(name, str(st.session_state.get(name, "")))
              for name in query_params_from_session]
    # Log parameter names only: the values include the auth token.
    logging.info("http_navigate_to query params %s",
                 [name for name, _ in params])
    if query_params:
        params.extend((str(key), str(value))
                      for key, value in query_params.items())

    # Merge everything into the URL's single query component. Appending
    # "?..." to the raw URL would produce a second "?" (and duplicate the
    # query) whenever `url` already has one. urlencode also percent-encodes
    # characters such as "&", "#" and quotes inside the values.
    url_obj = urlparse(url)
    query = urlencode(params)
    if url_obj.query:
        query = f"{query}&{url_obj.query}"
    target_url = url_obj._replace(query=query).geturl()

    # The URL is placed inside an HTML attribute rendered with
    # unsafe_allow_html, so escape it for HTML too.
    safe_url = html.escape(target_url, quote=True)
    nav_script = (
        f"\n<meta http-equiv=\"refresh\" content=\"0; url='{safe_url}'\">\n")
    st.write(nav_script, unsafe_allow_html=True)
def navigate_to(page_name):
    """Switch to another page of the app while keeping session_state.

    Page names are compared case-insensitively, with "_" treated as a space.
    If no page matches, the function returns without doing anything.

    Raises:
      RerunException: when a matching page is found. It carries that page's
        script hash and the normalized page name in a RerunData.
    """
    def standardize_name(name: str) -> str:
        return name.lower().replace("_", " ")

    wanted = standardize_name(page_name)
    for script_hash, page_info in get_pages("main.py").items():
        if standardize_name(page_info["page_name"]) != wanted:
            continue
        raise RerunException(
            RerunData(page_script_hash=script_hash, page_name=wanted))
def init_session_state():
    """Populate st.session_state from the page's URL query parameters.

    - "debug": set to True when the query param equals "true"
      (case-insensitive). Otherwise an existing truthy value is kept, and a
      missing or falsy one becomes False.
    - "error_msg": copied from the query params when non-empty.
    - "auth_token", "chat_id", "agent_name", "chat_llm_type",
      "default_route": copied from the query params (default "") only when
      the session does not already hold a truthy value.
    """
    query_params = st.query_params

    # Parse "debug" as a boolean. Copying the raw string, as is done for the
    # keys below, would make "?debug=false" a truthy "false".
    if query_params.get("debug", "").lower() == "true":
        st.session_state.debug = True
    elif not st.session_state.get("debug", None):
        st.session_state.debug = False

    error_msg = query_params.get("error_msg", "")
    if error_msg:
        st.session_state.error_msg = error_msg

    # Try to get a state var from query parameter.
    states_to_init = [
        "auth_token", "chat_id", "agent_name", "chat_llm_type",
        "default_route",
    ]
    for state_name in states_to_init:
        if not st.session_state.get(state_name, None):
            st.session_state[state_name] = query_params.get(state_name, "")

    # Debug-log the session without the auth token, so credentials never
    # reach stdout or the logs.
    logging.debug(
        "st.session_state: %s",
        {key: value for key, value in st.session_state.items()
         if key != "auth_token"})
def reset_session_state():
    """Reset the critical, chat-specific session state values.

    The landing input, current chat id, chat LLM type and error message are
    set to None, and the message history is replaced with a new empty list.
    """
    cleared_values = {
        "landing_user_input": None,
        "chat_id": None,
        "chat_llm_type": None,
        "messages": [],
        "error_msg": None,
    }
    for state_key, state_value in cleared_values.items():
        st.session_state[state_key] = state_value
def init_page(redirect_to_without_auth=True):
    """Run the common setup performed at the top of every page.

    Steps:
      1. Seed session_state from the URL (init_session_state).
      2. Show any pending error message.
      3. Optionally enforce authentication.
      4. Hide internal pages from the sidebar.
      5. Store the API base URL without trailing slashes.

    Args:
      redirect_to_without_auth: When True, a session without an auth token
        is sent to the Login page and the run is stopped. A token for which
        validate_auth_token() is falsy records an "Unauthorized" message in
        session_state.error_msg.
    """
    init_session_state()

    pending_error = st.session_state.get("error_msg", "")
    if pending_error:
        st.error(pending_error)

    if redirect_to_without_auth:
        if not st.session_state.get("auth_token", None):
            # Still no auth token after reading the query params: go to Login.
            navigate_to("Login")
            st.stop()
        if not validate_auth_token():
            # NOTE(review): the message is only stored here. This run has
            # already passed the st.error() call above, so it is presumably
            # shown on a later run unless the caller displays it -- confirm.
            st.session_state.error_msg = (
                "Unauthorized or session expired. "
                f"Please [Login]({APP_BASE_PATH}/Login) again.")

    # ./main.py is used as the build entrypoint, which creates a page named
    # "main" duplicating the Login page; keep it and Custom_Chat out of the
    # sidebar.
    hide_pages(["main", "Custom_Chat"])

    st.session_state.api_base_url = API_BASE_URL.rstrip("/")
    logging.info("st.session_state.api_base_url = %s",
                 st.session_state.api_base_url)
def hide_pages(hidden_pages: list[str]):
    """Hide the given pages (or whole sections) from the sidebar navigation.

    For every page entry whose position should be hidden, a CSS rule setting
    ``display: none`` on the matching sidebar ``li`` is emitted. Listing a
    section header entry (``is_section``) hides every page in that section.
    A page outside any section ends that section-wide hiding.
    """
    css_rules = []
    in_hidden_section = False
    for position, page in enumerate(get_pages("").values(), start=1):
        name = page.get("page_name")
        if page.get("is_section"):
            in_hidden_section = name in hidden_pages
        elif not page.get("in_section"):
            # Leaving a section: stop hiding pages on its behalf.
            in_hidden_section = False
        if not (name in hidden_pages or in_hidden_section):
            continue
        css_rules.append(
            "\ndiv[data-testid=\"stSidebarNav\"] "
            f"li:nth-child({position}) {{\ndisplay: none;\n}}\n")
    st.write("\n<style>\n" + "".join(css_rules) + "\n</style>\n",
             unsafe_allow_html=True)
def format_ai_output(text):
    """Convert raw agent output text into display-friendly markdown.

    Non-string input is returned unchanged. A string is first stripped and
    cleaned of ANSI escape sequences and leftover color codes: "[1;3m"
    becomes a newline and any other "[<digits/;>m" code is removed. Then
    agent step labels (Task, Observation, Thought, Action, ...) are rewritten
    as bold markdown.
    """
    if isinstance(text, str):
        cleaned = ansi_escape.sub("", text.strip())
        cleaned = re.sub(r"\[1;3m", "\n", cleaned)
        cleaned = re.sub(r"\[[\d;]+m", "", cleaned)
        # Applied in this order; each label is replaced everywhere it occurs.
        replacements = (
            ("> Entering new AgentExecutor chain",
             "**Entering new AgentExecutor chain**"),
            ("Task:", "- **Task**:"),
            ("Observation:", "---\n**Observation**:"),
            ("Thought:", "- **Thought**:"),
            ("Action:", "- **Action**:"),
            ("Action Input:", "- **Action Input**:"),
            ("Route:", "- **Route**:"),
            ("> Finished chain", "**Finished chain**"),
        )
        for label, markdown in replacements:
            cleaned = cleaned.replace(label, markdown)
        return cleaned
    return text
def print_json_content(key, value):
    """Render one key/value pair of a step's JSON content as markdown.

    For key "action_input", a dict value, or a string containing a Python
    dict literal, is rendered as a header followed by one nested bullet per
    entry. Every other case is rendered inline as " - **key**: ```value```".

    Args:
      key: Field name of the JSON content entry.
      value: Field value; any type.
    """
    output = f" - **{key}**: ```{value}```"
    if key == "action_input":
        parsed = value
        if isinstance(value, str):
            try:
                # literal_eval only evaluates Python literals; these are the
                # exceptions its documentation lists for malformed input.
                parsed = ast.literal_eval(value.strip())
            except (ValueError, TypeError, SyntaxError, MemoryError,
                    RecursionError):
                parsed = None
        # Only dicts get the nested-bullet layout. Other parsed literals
        # (str, int, list, ...) and non-str values keep the inline form
        # instead of crashing on .items() / .strip().
        if isinstance(parsed, dict):
            output = f"- **{key}**:" + "".join(
                f"\n - {sub_key}: {sub_value}"
                for sub_key, sub_value in parsed.items())
    st.markdown(output)
def print_ai_output(ai_output):
    """Render an agent's output in the Streamlit UI.

    Args:
      ai_output: Either a list of step dicts or a plain string.
        - A step dict with a "type" key is rendered as a bold type header,
          followed by its "json_content" entries (via print_json_content)
          or, when that is empty or missing, its "text_content". An
          "Observation" step is preceded by a horizontal rule. Steps with
          neither content are skipped.
        - A step without "type" is rendered as its stripped "text_content",
          with a leading "> " removed.
        - A non-blank string is cleaned by format_ai_output and written.
        Anything else is ignored.
    """
    if isinstance(ai_output, list):
        for item in ai_output:
            # A missing or None text_content is treated as empty text.
            text_content = item.get("text_content") or ""
            if "type" in item:
                if item["type"] == "Observation":
                    st.markdown("---")
                json_content = item.get("json_content")
                if not json_content and not text_content.strip():
                    continue
                st.markdown(f"**{item['type']}**")
                # Same truthiness test as the skip check above, so an
                # empty/None json_content falls back to the text instead of
                # rendering nothing or failing on None.items().
                if json_content:
                    for key, value in json_content.items():
                        print_json_content(key, value)
                else:
                    st.markdown(f" - {text_content}")
            else:
                text_content = text_content.strip()
                if text_content.startswith("> "):
                    text_content = text_content[2:]
                st.markdown(text_content)
    elif isinstance(ai_output, str) and ai_output.strip():
        st.write(format_ai_output(ai_output))