computer-use-demo/computer_use_demo/streamlit.py (404 lines of code) (raw):

"""
Entrypoint for streamlit, see https://docs.streamlit.io/
"""

import asyncio
import base64
import os
import subprocess
import traceback
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import StrEnum
from functools import partial
from pathlib import PosixPath
from typing import cast, get_args

import httpx
import streamlit as st
from anthropic import RateLimitError
from anthropic.types.beta import (
    BetaContentBlockParam,
    BetaTextBlockParam,
    BetaToolResultBlockParam,
)
from streamlit.delta_generator import DeltaGenerator

from computer_use_demo.loop import (
    APIProvider,
    sampling_loop,
)
from computer_use_demo.tools import ToolResult, ToolVersion

# Model name pre-filled in the sidebar for each API provider.
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
    APIProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
    APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
    APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
}


@dataclass(kw_only=True, frozen=True)
class ModelConfig:
    """Per-model defaults used to seed the sidebar controls."""

    # Tool version selected by default for the model.
    tool_version: ToolVersion
    # Upper bound applied to the "Thinking Budget" input.
    max_output_tokens: int
    # Initial value of the "Max Output Tokens" input.
    default_output_tokens: int
    # Stored to session state as `has_thinking`.
    has_thinking: bool = False


SONNET_3_5_NEW = ModelConfig(
    tool_version="computer_use_20241022",
    max_output_tokens=1024 * 8,
    default_output_tokens=1024 * 4,
)

SONNET_3_7 = ModelConfig(
    tool_version="computer_use_20250124",
    max_output_tokens=128_000,
    default_output_tokens=1024 * 16,
    has_thinking=True,
)

# Exact model-name lookups; unknown names fall back in _reset_model_conf.
MODEL_TO_MODEL_CONF: dict[str, ModelConfig] = {
    "claude-3-7-sonnet-20250219": SONNET_3_7,
}

CONFIG_DIR = PosixPath("~/.anthropic").expanduser()
API_KEY_FILE = CONFIG_DIR / "api_key"
STREAMLIT_STYLE = """
<style>
    /* Highlight the stop button in red */
    button[kind=header] {
        background-color: rgb(255, 75, 75);
        border: 1px solid rgb(255, 75, 75);
        color: rgb(255, 255, 255);
    }
    button[kind=header]:hover {
        background-color: rgb(255, 51, 51);
    }
     /* Hide the streamlit deploy button */
    .stAppDeployButton {
        visibility: hidden;
    }
</style>
"""

WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior"
INTERRUPT_TEXT = "(user stopped or interrupted and wrote the following)"
INTERRUPT_TOOL_ERROR = "human stopped or interrupted tool execution"


class Sender(StrEnum):
    """Roles used when rendering chat messages."""

    USER = "user"
    BOT = "assistant"
    TOOL = "tool"


def setup_state():
    """Populate every missing `st.session_state` key with its default.

    Keys that already exist are left untouched, so calling this at the top of
    each render pass is safe.
    """
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "api_key" not in st.session_state:
        # Try to load API key from file first, then environment
        st.session_state.api_key = load_from_storage("api_key") or os.getenv(
            "ANTHROPIC_API_KEY", ""
        )
    if "provider" not in st.session_state:
        st.session_state.provider = (
            os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC
        )
    if "provider_radio" not in st.session_state:
        st.session_state.provider_radio = st.session_state.provider
    if "model" not in st.session_state:
        # Depends on `provider`, which is guaranteed to be set above.
        _reset_model()
    if "auth_validated" not in st.session_state:
        st.session_state.auth_validated = False
    if "responses" not in st.session_state:
        st.session_state.responses = {}
    if "tools" not in st.session_state:
        st.session_state.tools = {}
    if "only_n_most_recent_images" not in st.session_state:
        st.session_state.only_n_most_recent_images = 3
    if "custom_system_prompt" not in st.session_state:
        st.session_state.custom_system_prompt = load_from_storage("system_prompt") or ""
    if "hide_images" not in st.session_state:
        st.session_state.hide_images = False
    if "token_efficient_tools_beta" not in st.session_state:
        st.session_state.token_efficient_tools_beta = False
    if "in_sampling_loop" not in st.session_state:
        st.session_state.in_sampling_loop = False


def _reset_model():
    """Set the model to the current provider's default and refresh its config."""
    st.session_state.model = PROVIDER_TO_DEFAULT_MODEL_NAME[
        cast(APIProvider, st.session_state.provider)
    ]
    _reset_model_conf()


def _reset_model_conf():
    """Re-derive the model-dependent session values from `st.session_state.model`."""
    # Any model name containing "3-7" gets the Sonnet 3.7 config; otherwise
    # look the name up exactly and fall back to the Sonnet 3.5 config.
    model_conf = (
        SONNET_3_7
        if "3-7" in st.session_state.model
        else MODEL_TO_MODEL_CONF.get(st.session_state.model, SONNET_3_5_NEW)
    )

    # If we're in radio selection mode, use the selected tool version
    if hasattr(st.session_state, "tool_versions"):
        st.session_state.tool_version = st.session_state.tool_versions
    else:
        st.session_state.tool_version = model_conf.tool_version

    st.session_state.has_thinking = model_conf.has_thinking
    st.session_state.output_tokens = model_conf.default_output_tokens
    st.session_state.max_output_tokens = model_conf.max_output_tokens
    st.session_state.thinking_budget = int(model_conf.default_output_tokens / 2)


async def main():
    """Render loop for streamlit

    Draws the sidebar and chat tabs, replays stored messages and HTTP
    exchanges, and, when the latest message is from the user, runs the agent
    sampling loop and stores the messages it returns.
    """
    setup_state()

    st.markdown(STREAMLIT_STYLE, unsafe_allow_html=True)

    st.title("Claude Computer Use Demo")

    if not os.getenv("HIDE_WARNING", False):
        st.warning(WARNING_TEXT)

    with st.sidebar:

        def _reset_api_provider():
            if st.session_state.provider_radio != st.session_state.provider:
                # Commit the new provider *before* resetting the model:
                # _reset_model reads st.session_state.provider to pick the
                # default model name, so the old order selected the previous
                # provider's default.
                st.session_state.provider = st.session_state.provider_radio
                _reset_model()
                st.session_state.auth_validated = False

        provider_options = [option.value for option in APIProvider]
        st.radio(
            "API Provider",
            options=provider_options,
            key="provider_radio",
            format_func=lambda x: x.title(),
            on_change=_reset_api_provider,
        )

        st.text_input("Model", key="model", on_change=_reset_model_conf)

        if st.session_state.provider == APIProvider.ANTHROPIC:
            st.text_input(
                "Anthropic API Key",
                type="password",
                key="api_key",
                on_change=lambda: save_to_storage("api_key", st.session_state.api_key),
            )

        st.number_input(
            "Only send N most recent images",
            min_value=0,
            key="only_n_most_recent_images",
            help="To decrease the total tokens sent, remove older screenshots from the conversation",
        )
        st.text_area(
            "Custom System Prompt Suffix",
            key="custom_system_prompt",
            help="Additional instructions to append to the system prompt. see computer_use_demo/loop.py for the base system prompt.",
            on_change=lambda: save_to_storage(
                "system_prompt", st.session_state.custom_system_prompt
            ),
        )
        st.checkbox("Hide screenshots", key="hide_images")
        st.checkbox(
            "Enable token-efficient tools beta", key="token_efficient_tools_beta"
        )
        versions = get_args(ToolVersion)
        st.radio(
            "Tool Versions",
            key="tool_versions",
            options=versions,
            index=versions.index(st.session_state.tool_version),
            on_change=lambda: setattr(
                st.session_state, "tool_version", st.session_state.tool_versions
            ),
        )

        st.number_input("Max Output Tokens", key="output_tokens", step=1)

        st.checkbox("Thinking Enabled", key="thinking", value=False)
        st.number_input(
            "Thinking Budget",
            key="thinking_budget",
            max_value=st.session_state.max_output_tokens,
            step=1,
            disabled=not st.session_state.thinking,
        )

        if st.button("Reset", type="primary"):
            with st.spinner("Resetting..."):
                # Drop all session state, then restore the defaults.
                st.session_state.clear()
                setup_state()

                # Kill the Xvfb and tint2 processes, wait briefly, then rerun
                # the start script.
                subprocess.run("pkill Xvfb; pkill tint2", shell=True)  # noqa: ASYNC221
                await asyncio.sleep(1)
                subprocess.run("./start_all.sh", shell=True)  # noqa: ASYNC221

    if not st.session_state.auth_validated:
        if auth_error := validate_auth(
            st.session_state.provider, st.session_state.api_key
        ):
            st.warning(f"Please resolve the following auth issue:\n\n{auth_error}")
            return
        else:
            st.session_state.auth_validated = True

    chat, http_logs = st.tabs(["Chat", "HTTP Exchange Logs"])
    new_message = st.chat_input(
        "Type a message to send to Claude to control the computer..."
    )

    with chat:
        # render past chats
        for message in st.session_state.messages:
            if isinstance(message["content"], str):
                _render_message(message["role"], message["content"])
            elif isinstance(message["content"], list):
                for block in message["content"]:
                    # the tool result we send back to the Anthropic API isn't sufficient to render all details,
                    # so we store the tool use responses
                    if isinstance(block, dict) and block["type"] == "tool_result":
                        _render_message(
                            Sender.TOOL, st.session_state.tools[block["tool_use_id"]]
                        )
                    else:
                        _render_message(
                            message["role"],
                            cast(BetaContentBlockParam | ToolResult, block),
                        )

        # render past http exchanges
        for identity, (request, response) in st.session_state.responses.items():
            _render_api_response(request, response, identity, http_logs)

        # record and render the newly submitted user message, if any
        if new_message:
            st.session_state.messages.append(
                {
                    "role": Sender.USER,
                    "content": [
                        *maybe_add_interruption_blocks(),
                        BetaTextBlockParam(type="text", text=new_message),
                    ],
                }
            )
            _render_message(Sender.USER, new_message)

        try:
            most_recent_message = st.session_state["messages"][-1]
        except IndexError:
            return

        # NOTE(review): this is an identity check against the enum member, so
        # a message whose role is a plain "user" string does not count as a
        # pending user message here — confirm this is intended before changing.
        if most_recent_message["role"] is not Sender.USER:
            # we don't have a user message to respond to, exit early
            return

        with track_sampling_loop():
            # run the agent sampling loop with the newest message
            st.session_state.messages = await sampling_loop(
                system_prompt_suffix=st.session_state.custom_system_prompt,
                model=st.session_state.model,
                provider=st.session_state.provider,
                messages=st.session_state.messages,
                output_callback=partial(_render_message, Sender.BOT),
                tool_output_callback=partial(
                    _tool_output_callback, tool_state=st.session_state.tools
                ),
                api_response_callback=partial(
                    _api_response_callback,
                    tab=http_logs,
                    response_state=st.session_state.responses,
                ),
                api_key=st.session_state.api_key,
                only_n_most_recent_images=st.session_state.only_n_most_recent_images,
                tool_version=st.session_state.tool_versions,
                max_tokens=st.session_state.output_tokens,
                thinking_budget=st.session_state.thinking_budget
                if st.session_state.thinking
                else None,
                token_efficient_tools_beta=st.session_state.token_efficient_tools_beta,
            )


def maybe_add_interruption_blocks():
    """Return the content blocks to prepend to a user message after an interruption.

    Returns an empty list when no sampling loop was in progress. Otherwise
    returns one error `tool_result` block for every `tool_use` block in the
    last stored message, followed by a text block carrying INTERRUPT_TEXT.
    Each synthesized tool result is also recorded in `st.session_state.tools`.
    """
    if not st.session_state.in_sampling_loop:
        return []
    # If this function is called while we're in the sampling loop, we can assume that the previous sampling loop was interrupted
    # and we should annotate the conversation with additional context for the model and heal any incomplete tool use calls
    result = []
    last_message = st.session_state.messages[-1]
    previous_tool_use_ids = [
        block["id"] for block in last_message["content"] if block["type"] == "tool_use"
    ]
    for tool_use_id in previous_tool_use_ids:
        st.session_state.tools[tool_use_id] = ToolResult(error=INTERRUPT_TOOL_ERROR)
        result.append(
            BetaToolResultBlockParam(
                tool_use_id=tool_use_id,
                type="tool_result",
                content=INTERRUPT_TOOL_ERROR,
                is_error=True,
            )
        )
    result.append(BetaTextBlockParam(type="text", text=INTERRUPT_TEXT))
    return result


@contextmanager
def track_sampling_loop():
    """Flag `st.session_state.in_sampling_loop` for the duration of the block.

    NOTE(review): there is no try/finally here, so the flag stays True when
    the body does not finish normally. maybe_add_interruption_blocks treats a
    still-set flag as "the previous loop was interrupted", so this looks
    deliberate — confirm before adding cleanup.
    """
    st.session_state.in_sampling_loop = True
    yield
    st.session_state.in_sampling_loop = False


def validate_auth(provider: APIProvider, api_key: str | None) -> str | None:
    """Check that credentials for `provider` are available.

    Returns a human-readable description of the problem, or None when no
    problem was found.
    """
    if provider == APIProvider.ANTHROPIC:
        if not api_key:
            return "Enter your Anthropic API key in the sidebar to continue."
    if provider == APIProvider.BEDROCK:
        # Imported lazily so boto3 is only needed when Bedrock is selected.
        import boto3

        if not boto3.Session().get_credentials():
            return "You must have AWS credentials set up to use the Bedrock API."
    if provider == APIProvider.VERTEX:
        # Imported lazily so google-auth is only needed when Vertex is selected.
        import google.auth
        from google.auth.exceptions import DefaultCredentialsError

        if not os.environ.get("CLOUD_ML_REGION"):
            return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
        try:
            google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        except DefaultCredentialsError:
            return "Your google cloud credentials are not set up correctly."
    return None
def load_from_storage(filename: str) -> str | None:
    """Load data from a file in the storage directory.

    Returns the stripped file contents, or None when the file is missing,
    empty, or cannot be read (the read error is written to the page).
    """
    try:
        file_path = CONFIG_DIR / filename
        if file_path.exists():
            data = file_path.read_text().strip()
            if data:
                return data
    except Exception as e:
        # Best-effort: a broken storage file must not stop the app from rendering.
        st.write(f"Debug: Error loading {filename}: {e}")
    return None


def save_to_storage(filename: str, data: str) -> None:
    """Save data to a file in the storage directory.

    The file is created (or tightened) to mode 0o600 *before* the data is
    written, so stored secrets such as the API key are never on disk with
    wider permissions. Failures are reported on the page rather than raised.
    """
    try:
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)
        file_path = CONFIG_DIR / filename
        # Ensure only user can read/write the file. touch() covers a new file,
        # chmod() covers a pre-existing file with looser permissions.
        file_path.touch(mode=0o600, exist_ok=True)
        file_path.chmod(0o600)
        file_path.write_text(data)
    except Exception as e:
        # Best-effort: persistence problems are surfaced but not fatal.
        st.write(f"Debug: Error saving {filename}: {e}")


def _api_response_callback(
    request: httpx.Request,
    response: httpx.Response | object | None,
    error: Exception | None,
    tab: DeltaGenerator,
    response_state: dict[str, tuple[httpx.Request, httpx.Response | object | None]],
):
    """
    Handle an API response by storing it to state and rendering it.
    """
    # The current timestamp doubles as the storage key and the expander label.
    response_id = datetime.now().isoformat()
    response_state[response_id] = (request, response)
    if error:
        _render_error(error)
    _render_api_response(request, response, response_id, tab)


def _tool_output_callback(
    tool_output: ToolResult, tool_id: str, tool_state: dict[str, ToolResult]
):
    """Handle a tool output by storing it to state and rendering it."""
    tool_state[tool_id] = tool_output
    _render_message(Sender.TOOL, tool_output)


def _render_api_response(
    request: httpx.Request,
    response: httpx.Response | object | None,
    response_id: str,
    tab: DeltaGenerator,
):
    """Render an API response to a streamlit tab"""
    with tab:
        with st.expander(f"Request/Response ({response_id})"):
            # Blank-line separator so each header renders as its own paragraph.
            newline = "\n\n"
            st.markdown(
                f"`{request.method} {request.url}`{newline}{newline.join(f'`{k}: {v}`' for k, v in request.headers.items())}"
            )
            st.json(request.read().decode())
            st.markdown("---")
            if isinstance(response, httpx.Response):
                st.markdown(
                    f"`{response.status_code}`{newline}{newline.join(f'`{k}: {v}`' for k, v in response.headers.items())}"
                )
                st.json(response.text)
            else:
                # Not an httpx response: let streamlit pick a rendering.
                st.write(response)


def _render_error(error: Exception):
    """Show an error on the page and save the same text to the storage directory.

    Rate-limit errors get a short explanation (with the retry delay when the
    response carries a `retry-after` header); every other error is shown with
    its traceback.
    """
    if isinstance(error, RateLimitError):
        body = "You have been rate limited."
        if retry_after := error.response.headers.get("retry-after"):
            try:
                retry_hint = f"{timedelta(seconds=int(retry_after))} (HH:MM:SS)"
            except ValueError:
                # The header is not an integer number of seconds; show the
                # raw value rather than failing while reporting an error.
                retry_hint = retry_after
            body += f" **Retry after {retry_hint}.** See our API [documentation](https://docs.anthropic.com/en/api/rate-limits) for more details."
        body += f"\n\n{error.message}"
    else:
        body = str(error)
        body += "\n\n**Traceback:**"
        # format_exception's strings already end in newlines, so join with ""
        # to avoid a double-spaced traceback.
        lines = "".join(traceback.format_exception(error))
        body += f"\n\n```{lines}```"
    save_to_storage(f"error_{datetime.now().timestamp()}.md", body)
    st.error(f"**{error.__class__.__name__}**\n\n{body}", icon=":material/error:")


def _render_message(
    sender: Sender,
    message: str | BetaContentBlockParam | ToolResult,
):
    """Convert input from the user or output from the agent to a streamlit message."""
    # streamlit's hotreloading breaks isinstance checks, so we need to check for class names
    is_tool_result = not isinstance(message, str | dict)
    if not message:
        return
    if (
        is_tool_result
        and st.session_state.hide_images
        and not getattr(message, "error", None)
        and not getattr(message, "output", None)
    ):
        # Screenshots are hidden and the tool result has no output or error
        # text, so there is nothing to show. The attribute *values* are
        # checked: hasattr() is true for any object whose .output/.error are
        # read below, which made the original guard unreachable.
        return
    with st.chat_message(sender):
        if is_tool_result:
            message = cast(ToolResult, message)
            if message.output:
                if message.__class__.__name__ == "CLIResult":
                    st.code(message.output)
                else:
                    st.markdown(message.output)
            if message.error:
                st.error(message.error)
            if message.base64_image and not st.session_state.hide_images:
                st.image(base64.b64decode(message.base64_image))
        elif isinstance(message, dict):
            if message["type"] == "text":
                st.write(message["text"])
            elif message["type"] == "thinking":
                thinking_content = message.get("thinking", "")
                st.markdown(f"[Thinking]\n\n{thinking_content}")
            elif message["type"] == "tool_use":
                st.code(f'Tool Use: {message["name"]}\nInput: {message["input"]}')
            else:
                # only expected return types are text and tool_use
                raise Exception(f'Unexpected response type {message["type"]}')
        else:
            st.markdown(message)
if __name__ == "__main__":
    # Script entry point: run the async render loop to completion.
    asyncio.run(main())