databao/executors/lighthouse/graph.py (285 lines of code) (raw):

from collections.abc import Sequence from typing import Annotated, Any, Literal import pandas as pd from duckdb import DuckDBPyConnection from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.messages import AIMessage, BaseMessage, ToolMessage from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool, tool from langchain_openai import ChatOpenAI from langgraph.constants import END, START from langgraph.graph import add_messages from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.prebuilt import InjectedState from typing_extensions import TypedDict from databao.configs.llm import LLMConfig from databao.core import ExecutionResult from databao.duckdb.react_tools import execute_duckdb_sql from databao.executors.frontend.text_frontend import dataframe_to_markdown from databao.executors.lighthouse.utils import exception_to_string class AgentState(TypedDict): messages: Annotated[list[BaseMessage], add_messages] query_ids: dict[str, ToolMessage] sql: str | None df: pd.DataFrame | None visualization_prompt: str | None ready_for_user: bool limit_max_rows: int | None def get_query_ids_mapping(messages: list[BaseMessage]) -> dict[str, ToolMessage]: query_ids = {} for message in messages: if isinstance(message, ToolMessage) and isinstance(message.artifact, dict) and "query_id" in message.artifact: query_ids[message.artifact["query_id"]] = message return query_ids class ExecuteSubmit: """Simple graph with two tools: run_sql_query and submit_result. All context must be in the SystemMessage.""" MAX_TOOL_ROWS = 12 """Max number of rows to return in SQL tool calls.""" def __init__(self, connection: DuckDBPyConnection): self._connection = connection def init_state(self, messages: list[BaseMessage], *, limit_max_rows: int | None = None) -> AgentState: return AgentState( messages=messages, query_ids=get_query_ids_mapping(messages), sql=None, df=None, visualization_prompt=None, ready_for_user=False, limit_max_rows=limit_max_rows, ) def get_result(self, state: AgentState) -> ExecutionResult: last_ai_message = None for m in reversed(state["messages"]): if isinstance(m, AIMessage): last_ai_message = m break if last_ai_message is None: raise RuntimeError("No AI message found in message log") if len(last_ai_message.tool_calls) == 0: # Sometimes models don't call the submit_result tool, but we still want to return some dataframe. sql = state.get("sql", "") df = state.get("df") # Latest df result (usually from run_sql_query) visualization_prompt = state.get("visualization_prompt") result = ExecutionResult( text=last_ai_message.text(), df=df, code=sql, meta={ "visualization_prompt": visualization_prompt, "messages": state["messages"], "submit_called": False, }, ) elif len(last_ai_message.tool_calls) > 1: raise RuntimeError("Expected exactly one tool call in AI message") elif last_ai_message.tool_calls[0]["name"] != "submit_result": raise RuntimeError( f"Expected submit_result tool call in AI message, got {last_ai_message.tool_calls[0]['name']}" ) else: sql = state.get("sql", "") df = state.get("df") tool_call = last_ai_message.tool_calls[0] text = tool_call["args"]["result_description"] visualization_prompt = state.get("visualization_prompt", "") result = ExecutionResult( text=text, df=df, code=sql, meta={ "visualization_prompt": visualization_prompt, "messages": state["messages"], "submit_called": True, }, ) return result def make_tools(self) -> list[BaseTool]: @tool(parse_docstring=True) def run_sql_query(sql: str, graph_state: Annotated[AgentState, InjectedState]) -> dict[str, Any]: """ Run a SELECT SQL query in the database. Returns the first 12 rows in csv format. Args: sql: SQL query """ try: # TODO use ToolRuntime in LangChain v1.0 limit = graph_state["limit_max_rows"] df = execute_duckdb_sql(sql, self._connection, limit=limit) df_csv = df.head(self.MAX_TOOL_ROWS).to_csv(index=False) df_markdown = dataframe_to_markdown(df.head(self.MAX_TOOL_ROWS), index=False) if len(df) > self.MAX_TOOL_ROWS: df_csv += f"\nResult is truncated from {len(df)} to {self.MAX_TOOL_ROWS} rows." df_markdown += f"\nResult is truncated from {len(df)} to {self.MAX_TOOL_ROWS} rows." return {"df": df, "sql": sql, "csv": df_csv, "markdown": df_markdown} except Exception as e: return {"error": exception_to_string(e)} @tool(parse_docstring=True) def submit_result( query_id: str, result_description: str, visualization_prompt: str, ) -> str: """ Call this tool with the ID of the query you want to submit to the user. This will return control to the user and must always be the last tool call. The user will see the full query result, not just the first 12 rows. Returns a confirmation message. Args: query_id: The ID of the query to submit (query_ids are automatically generated when you run queries). result_description: A comment to a final result. This will be included in the final result. visualization_prompt: Optional visualization prompt. If not empty, a Vega-Lite visualization agent will be asked to plot the submitted query data according to instructions in the prompt. The instructions should be short and simple. """ return f"Query {query_id} submitted successfully. Your response is now visible to the user." tools = [run_sql_query, submit_result] return tools def compile(self, model_config: LLMConfig) -> CompiledStateGraph[Any]: tools = self.make_tools() llm_model = model_config.new_chat_model() model_with_tools = self._model_bind_tools(llm_model, tools) def llm_node(state: AgentState) -> dict[str, Any]: messages = state["messages"] response = self._chat(messages, model_config, model_with_tools) return {"messages": [response[-1]]} def tool_executor_node(state: AgentState) -> dict[str, Any]: last_message = state["messages"][-1] tool_messages = [] assert isinstance(last_message, AIMessage) tool_calls = last_message.tool_calls is_ready_for_user = any(tc["name"] == "submit_result" for tc in tool_calls) if is_ready_for_user: if len(tool_calls) > 1: tool_messages = [ ToolMessage("submit_result must be the only tool call.", tool_call_id=tool_call["id"]) for tool_call in tool_calls ] return {"messages": tool_messages, "ready_for_user": False} else: tool_call = tool_calls[0] if "query_ids" not in state or len(state["query_ids"]) == 0: tool_messages = [ ToolMessage("No queries have been executed yet.", tool_call_id=tool_call["id"]) ] return {"messages": tool_messages, "ready_for_user": False} query_id = tool_call["args"]["query_id"] if query_id not in state["query_ids"]: available_ids = ", ".join(state["query_ids"].keys()) tool_messages = [ ToolMessage( f"Query ID {query_id} not found. Available query IDs: {available_ids}", tool_call_id=tool_call["id"], ) ] return {"messages": tool_messages, "ready_for_user": False} target_tool_message = state["query_ids"][query_id] if target_tool_message.artifact is None or "df" not in target_tool_message.artifact: tool_messages = [ ToolMessage(f"Query {query_id} does not have a valid result.", tool_call_id=tool_call["id"]) ] return {"messages": tool_messages, "ready_for_user": False} query_ids = dict(state.get("query_ids", {})) sql = state.get("sql") df = state.get("df") visualization_prompt = state.get("visualization_prompt", "") message_index = len(state["messages"]) - 1 for idx, tool_call in enumerate(tool_calls): name = tool_call["name"] args = tool_call["args"] tool_call_id = tool_call["id"] # Find the tool by name tool = next((t for t in tools if t.name == name), None) if tool is None: tool_messages.append(ToolMessage(content=f"Tool {name} does not exist!", tool_call_id=tool_call_id)) continue try: result = tool.invoke(args | {"graph_state": state}) except Exception as e: result = {"error": exception_to_string(e) + f"\nTool: {name}, Args: {args}"} content = "" if name == "run_sql_query": sql = result.get("sql") df = result.get("df") # Generate query_id using message index and tool call index query_id = f"{message_index}-{idx}" # Override the query_id in the result result["query_id"] = query_id content = result.get("csv", result.get("error", "")) if "csv" in result: content = f"query_id='{query_id}'\n\n{content}" if query_id: query_ids[query_id] = ToolMessage( content=content, tool_call_id=tool_call_id, artifact=result, ) elif name == "submit_result": content = str(result) query_id = tool_call["args"]["query_id"] visualization_prompt = tool_call["args"].get("visualization_prompt", "") sql = state["query_ids"][query_id].artifact["sql"] df = state["query_ids"][query_id].artifact["df"] tool_messages.append(ToolMessage(content=content, tool_call_id=tool_call_id, artifact=result)) if name == "submit_result": return { "messages": tool_messages, "sql": sql, "df": df, "visualization_prompt": visualization_prompt, "ready_for_user": True, } return { "messages": tool_messages, "query_ids": query_ids, "sql": sql, "df": df, "visualization_prompt": visualization_prompt, "ready_for_user": False, } def should_continue(state: AgentState) -> Literal["tool_executor", "end"]: # Check if there are tool calls in the last message last_message = state["messages"][-1] if isinstance(last_message, AIMessage) and last_message.tool_calls: return "tool_executor" return "end" def should_finish(state: AgentState) -> Literal["llm_node", "end"]: # Check if we just executed submit_result - if so, end the conversation if state.get("ready_for_user", False): return "end" return "llm_node" graph = StateGraph(AgentState) graph.add_node("llm_node", llm_node) graph.add_node("tool_executor", tool_executor_node) graph.add_edge(START, "llm_node") graph.add_conditional_edges("llm_node", should_continue, {"tool_executor": "tool_executor", "end": END}) graph.add_conditional_edges("tool_executor", should_finish, {"llm_node": "llm_node", "end": END}) return graph.compile() @staticmethod def _model_bind_tools( model: BaseChatModel, tools: Sequence[BaseTool], **kwargs: Any ) -> Runnable[LanguageModelInput, BaseMessage]: if isinstance(model, ChatOpenAI): return model.bind_tools(tools, strict=True, **kwargs) else: return model.bind_tools(tools, **kwargs) @staticmethod def _chat( messages: list[BaseMessage], config: LLMConfig, model: Runnable[list[BaseMessage], Any] | None = None, ) -> list[BaseMessage]: if model is None: model = config.new_chat_model() messages = ExecuteSubmit._apply_system_prompt_caching(config, messages) response: AIMessage = ExecuteSubmit._call_model(model, messages) return [*messages, response] @staticmethod def _is_anthropic_model(config: LLMConfig) -> bool: """Check if the model is an Anthropic model based on the config name.""" return "claude" in config.name.lower() @staticmethod def _apply_system_prompt_caching(config: LLMConfig, messages: list[BaseMessage]) -> list[BaseMessage]: """Apply system prompt caching for Anthropic models.""" if not (config.cache_system_prompt and ExecuteSubmit._is_anthropic_model(config)): return messages # Assume only the first message can be a system prompt. assert all(m.type != "system" for m in messages[1:]) if messages[0].type == "system": messages = [ExecuteSubmit._set_message_cache_breakpoint(config, messages[0]), *messages[1:]] return messages @staticmethod def _set_message_cache_breakpoint(config: LLMConfig, message: BaseMessage) -> BaseMessage: """Enable prompt caching for this message (for Anthropic models). If you have a list of messages, set a breakpoint only on the last message to automatically cache all previous messages. See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching > Prompt caching references the entire prompt - tools, system, and messages (in that order) up to and including the block designated with cache_control. """ if not ExecuteSubmit._is_anthropic_model(config): return message new_content: list[dict[str, Any] | str] match message.content: case str() | dict(): new_content = [ExecuteSubmit._set_anthropic_cache_breakpoint(message.content)] case list(): # Set checkpoint only for the last message new_content = message.content.copy() new_content[-1] = ExecuteSubmit._set_anthropic_cache_breakpoint(new_content[-1]) return message.model_copy(update={"content": new_content}) @staticmethod def _set_anthropic_cache_breakpoint(content: str | dict[str, Any]) -> dict[str, Any]: if isinstance(content, str): return {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} elif isinstance(content, dict): d = content.copy() d["cache_control"] = {"type": "ephemeral"} return d else: raise ValueError(f"Unknown content type: {type(content)}") @staticmethod def _call_model(model: Runnable[list[BaseMessage], Any], messages: list[BaseMessage]) -> Any: return model.with_retry(wait_exponential_jitter=True, stop_after_attempt=3).invoke(messages)