src/frontends/streamlit/frontend/streamlit_app.py (203 lines of code) (raw):
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: disable-error-code="arg-type"
import json
import uuid
from collections.abc import Sequence
from functools import partial
from typing import Any
import streamlit as st
from langchain_core.messages import HumanMessage
from streamlit_feedback import streamlit_feedback
from frontend.side_bar import SideBar
from frontend.style.app_markdown import MARKDOWN_STR
from frontend.utils.local_chat_history import LocalChatMessageHistory
from frontend.utils.message_editing import MessageEditing
from frontend.utils.multimodal_utils import format_content, get_parts_from_files
from frontend.utils.stream_handler import Client, StreamHandler, get_chain_response
# Identifier stored as `st.session_state.user_id` when the session state is
# first initialized.
USER = "my_user"
# Placeholder title given to a newly created chat; `update_chat_title` compares
# against it to decide whether a title still has to be set.
EMPTY_CHAT_NAME = "Empty chat"
def setup_page() -> None:
    """Apply the page-level configuration, then render the title and styles."""
    page_settings: dict[str, Any] = {
        "page_title": "Playground",
        "layout": "wide",
        "initial_sidebar_state": "auto",
        "menu_items": None,
    }
    st.set_page_config(**page_settings)
    st.title("Playground")
    # MARKDOWN_STR is rendered with HTML enabled.
    st.markdown(MARKDOWN_STR, unsafe_allow_html=True)
def initialize_session_state() -> None:
    """Populate ``st.session_state`` with defaults on the first run.

    The presence of the ``user_chats`` key marks an already-initialized
    session; in that case nothing is modified.
    """
    state = st.session_state
    if "user_chats" in state:
        return

    new_session_id = str(uuid.uuid4())
    state["session_id"] = new_session_id
    state.uploader_key = 0
    state.run_id = None
    state.user_id = USER
    state["gcs_uris_to_be_sent"] = ""
    state.modified_prompt = None

    # The history store is created for the freshly generated session and user.
    state.session_db = LocalChatMessageHistory(
        session_id=new_session_id,
        user_id=USER,
    )

    # Load the stored conversations, then register the new (empty) chat.
    state.user_chats = state.session_db.get_all_conversations()
    state.user_chats[new_session_id] = {"title": EMPTY_CHAT_NAME, "messages": []}
def display_messages() -> None:
    """Display all messages in the current chat session.

    "ai" and "human" messages with non-empty content are rendered as chat
    bubbles. Tool calls attached to any message are remembered by ID so that
    a later "tool" message can be rendered next to the call that produced it.

    Raises:
        ValueError: If a message is not of type "ai", "human" or "tool" and
            carries no tool calls.
    """
    messages = st.session_state.user_chats[st.session_state["session_id"]]["messages"]
    tool_calls_map: dict[Any, dict[str, Any]] = {}  # tool_call_id -> tool call input
    for i, message in enumerate(messages):
        message_type = message["type"]
        tool_calls = message.get("tool_calls")

        # Record tool calls before deciding what to render: a message may hold
        # both visible content and tool calls, and the matching "tool" message
        # needs the call input regardless of whether the content was shown.
        for tool_call in tool_calls or ():
            tool_calls_map[tool_call["id"]] = tool_call

        if message_type in ("ai", "human"):
            # Messages without content have nothing to show; they are skipped
            # rather than reported as an unexpected type.
            if message["content"]:
                display_chat_message(message, i)
        elif message_type == "tool":
            # Look up the corresponding tool call input by ID
            tool_call_id = message["tool_call_id"]
            if tool_call_id in tool_calls_map:
                display_tool_output(tool_calls_map[tool_call_id], message)
            else:
                st.error(f"Could not find tool call input for ID: {tool_call_id}")
        elif not tool_calls:
            st.error(f"Unexpected message type: {message_type}")
            st.write("Full messages list:", messages)
            raise ValueError(f"Unexpected message type: {message_type}")
def display_chat_message(message: dict[str, Any], index: int) -> None:
    """Render one chat message followed by its row of action buttons.

    Args:
        message: Message dict; its ``type`` selects the chat bubble and its
            ``content`` is formatted for display.
        index: Position of the message in the chat, forwarded to the buttons.
    """
    with st.chat_message(message["type"]):
        rendered = format_content(message["content"])
        st.markdown(rendered, unsafe_allow_html=True)
        # Two narrow columns, then a wide one taking the remaining width.
        first_col, second_col, third_col = st.columns([2, 2, 94])
        display_message_buttons(message, index, first_col, second_col, third_col)
def display_message_buttons(
    message: dict[str, Any], index: int, col1: Any, col2: Any, col3: Any
) -> None:
    """Render the edit button and, for human messages, refresh and delete.

    Args:
        message: Message dict with ``type`` and ``content`` entries.
        index: Position of the message; used to build unique widget keys.
        col1: Container for the edit button.
        col2: Container for the refresh button (human messages only).
        col3: Container for the delete button (human messages only).
    """
    edit_key, refresh_key, delete_key = (
        f"{index}_{action}" for action in ("edit", "refresh", "delete")
    )

    # Content is either a plain string or a list of parts; for a list, the
    # text of the last part is used.
    raw_content = message["content"]
    if isinstance(raw_content, str):
        text = raw_content
    else:
        text = raw_content[-1]["text"]

    with col1:
        st.button(label="✎", key=edit_key, type="primary")

    if message["type"] == "human":
        on_refresh = partial(MessageEditing.refresh_message, st, index, text)
        on_delete = partial(MessageEditing.delete_message, st, index)
        with col2:
            st.button(
                label="⟳", key=refresh_key, type="primary", on_click=on_refresh
            )
        with col3:
            st.button(label="X", key=delete_key, type="primary", on_click=on_delete)

    # The edit button's state lives in session state under its widget key.
    if st.session_state[edit_key]:
        on_edit = partial(MessageEditing.edit_message, st, index, message["type"])
        st.text_area(
            "Edit your message:",
            value=text,
            key=f"edit_box_{index}",
            on_change=on_edit,
        )
def display_tool_output(
    tool_call_input: dict[str, Any], tool_call_output: dict[str, Any]
) -> None:
    """Display the input and output of a tool call in an expander.

    Args:
        tool_call_input: Tool call entry taken from a message's ``tool_calls``.
        tool_call_output: The "tool" message that answers that call.
    """
    # Tool call entries are expected to carry "name" and "args" (TODO confirm
    # against the agent's message schema); when a key is missing, fall back to
    # showing the whole record instead of failing.
    tool_name = tool_call_input.get("name", tool_call_input)
    tool_args = tool_call_input.get("args", tool_call_input)

    with st.expander(label="Tool Calls:", expanded=False):
        # default=str keeps rendering from crashing on values that json cannot
        # serialize natively; this output is for display only.
        msg = (
            f"\n\nEnding tool: `{tool_name}` with\n **args:**\n"
            f"```\n{json.dumps(tool_args, indent=2, default=str)}\n```\n"
            f"\n\n**output:**\n "
            f"```\n{json.dumps(tool_call_output, indent=2, default=str)}\n```"
        )
        st.markdown(msg, unsafe_allow_html=True)
def handle_user_input(side_bar: SideBar) -> None:
    """Take the pending prompt, get the AI answer, and store the exchange.

    The prompt comes from the chat input box or, when that is empty, from
    ``st.session_state.modified_prompt``. Nothing happens without a prompt.

    Args:
        side_bar: Sidebar object supplying uploaded files, GCS URIs and the
            connection settings for the agent.
    """
    prompt = st.chat_input() or st.session_state.modified_prompt
    if not prompt:
        return

    state = st.session_state
    state.modified_prompt = None

    # Build the message parts from the attachments, then append the text last.
    parts = get_parts_from_files(
        upload_gcs_checkbox=state.checkbox_state,
        uploaded_files=side_bar.uploaded_files,
        gcs_uris=side_bar.gcs_uris,
    )
    state["gcs_uris_to_be_sent"] = ""
    parts.append({"type": "text", "text": prompt})

    human_message = HumanMessage(content=parts).model_dump()
    state.user_chats[state["session_id"]]["messages"].append(human_message)

    display_user_input(parts)
    generate_ai_response(
        remote_agent_engine_id=side_bar.remote_agent_engine_id,
        agent_callable_path=side_bar.agent_callable_path,
        url=side_bar.url_input_field,
        authenticate_request=side_bar.should_authenticate_request,
    )
    update_chat_title()

    # More than one part means something besides the text was attached;
    # bump the uploader key in that case.
    has_attachments = len(parts) > 1
    if has_attachments:
        state.uploader_key += 1
    st.rerun()
def display_user_input(parts: Sequence[dict[str, Any]]) -> None:
    """Render the user's message parts inside a "human" chat bubble.

    Args:
        parts: Content parts of the user's message, passed to
            ``format_content`` before being rendered as markdown.
    """
    with st.chat_message("human"):
        st.markdown(format_content(parts), unsafe_allow_html=True)
def generate_ai_response(
    remote_agent_engine_id: str | None = None,
    agent_callable_path: str | None = None,
    url: str | None = None,
    authenticate_request: bool = False,
) -> None:
    """Request the AI answer and show it in an "ai" chat bubble.

    Args:
        remote_agent_engine_id: Forwarded to ``Client``.
        agent_callable_path: Forwarded to ``Client``.
        url: Forwarded to ``Client``.
        authenticate_request: Forwarded to ``Client``.
    """
    with st.chat_message("ai"):
        progress = st.status("Generating answer🤖")
        handler = StreamHandler(st=st)
        agent_client = Client(
            remote_agent_engine_id=remote_agent_engine_id,
            agent_callable_path=agent_callable_path,
            url=url,
            authenticate_request=authenticate_request,
        )
        get_chain_response(st=st, client=agent_client, stream_handler=handler)
        # Mark the status widget as done once the response call has returned.
        progress.update(label="Finished!", state="complete", expanded=False)
def update_chat_title() -> None:
    """Set a title for the current chat if it still has the placeholder one.

    The current chat is then passed to ``upsert_session`` whether or not the
    title was changed.
    """
    chats = st.session_state.user_chats
    session_id = st.session_state["session_id"]
    history = st.session_state.session_db

    still_untitled = chats[session_id]["title"] == EMPTY_CHAT_NAME
    if still_untitled:
        history.set_title(chats[session_id])
    history.upsert_session(chats[session_id])
def display_feedback(side_bar: SideBar) -> None:
    """Show the feedback widget for the latest run and log any submission.

    Nothing is shown while ``st.session_state.run_id`` is ``None``.

    Args:
        side_bar: Sidebar object supplying the connection settings used to
            build the client that logs the feedback.
    """
    run_id = st.session_state.run_id
    if run_id is None:
        return

    feedback = streamlit_feedback(
        feedback_type="faces",
        optional_text_label="[Optional] Please provide an explanation",
        key=f"feedback-{run_id}",
    )
    if feedback is None:
        return

    feedback_client = Client(
        remote_agent_engine_id=side_bar.remote_agent_engine_id,
        agent_callable_path=side_bar.agent_callable_path,
        url=side_bar.url_input_field,
        authenticate_request=side_bar.should_authenticate_request,
    )
    feedback_client.log_feedback(feedback_dict=feedback, run_id=run_id)
def main() -> None:
    """Build the page: setup, session state, sidebar, chat, input, feedback."""
    setup_page()
    initialize_session_state()

    sidebar = SideBar(st=st)
    sidebar.init_side_bar()

    display_messages()
    handle_user_input(side_bar=sidebar)
    display_feedback(side_bar=sidebar)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()