databao/executors/frontend/text_frontend.py (104 lines of code) (raw):

import re
from typing import Any, TextIO

import pandas as pd
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ToolMessage

from databao.executors.frontend.messages import get_reasoning_content, get_tool_call, get_tool_call_sql


class TextStreamFrontend:
    """Helper for streaming LangGraph LLM outputs to a text stream (stdout, stderr, a file, etc.).

    Token chunks are rendered by ``write_message_chunk`` and state snapshots by
    ``write_state_chunk``; ``write_stream_chunk`` dispatches between the two based on the
    stream mode. The first write emits an opening header (when headers are enabled) and
    ``end`` emits the matching footer.
    """

    def __init__(
        self,
        start_state: dict[str, Any],
        *,
        writer: TextIO | None = None,
        escape_markdown: bool = False,
        show_headers: bool = True,
        pretty_sql: bool = True,
    ) -> None:
        """Create a frontend.

        Args:
            start_state: Initial graph state. Messages already present under its
                ``"messages"`` key are counted as seen and are never rendered.
            writer: Target stream, forwarded to ``print(file=...)``. ``None`` lets ``print``
                use its default stream.
            escape_markdown: If true, streamed AI text is passed through
                ``escape_markdown_text`` before being written.
            show_headers: If true, wrap the output in ``<THINKING>`` header/footer lines.
            pretty_sql: If true, SQL found in the tool calls of new AI messages is written
                again inside a fenced ``sql`` code block.
        """
        self._writer = writer  # Use io.Writer type in Python 3.14
        self._escape_markdown = escape_markdown
        self._show_headers = show_headers
        self._message_count = len(start_state.get("messages", []))
        self._started = False
        self._is_tool_calling = False
        self._pretty_sql = pretty_sql

    def write(self, text: str) -> None:
        """Write ``text`` to the stream and flush, emitting the start header on first use."""
        if not self._started:
            self.start()
        print(text, end="", flush=True, file=self._writer)

    def write_dataframe(self, df: pd.DataFrame, *, name: str | None = None, max_rows: int = 10) -> None:
        """Write a one-line summary followed by the first ``max_rows`` rows of ``df``.

        Args:
            df: Dataframe to render.
            name: Optional label shown in the summary line.
            max_rows: Maximum number of rows to render. Negative values are treated as 0.
        """
        total_rows = len(df)
        # Clamp at zero: a negative count would be reported as-is in the summary and
        # ``df.head(-k)`` would render "all but the last k rows" instead of nothing.
        rows_to_show = max(0, min(max_rows, total_rows))
        self.write(f"[df: name={name or ''}, showing {rows_to_show} / {total_rows} rows]\n")
        df_str = dataframe_to_markdown(df.head(rows_to_show))
        self.write(f"{df_str}\n\n")

    def write_message_chunk(self, chunk: BaseMessageChunk) -> None:
        """Render one streamed token chunk.

        Only ``AIMessageChunk`` instances are rendered: their reasoning content and text are
        written first, then any tool-call argument fragments, which are wrapped in a code
        fence that stays open until a chunk without tool-call fragments arrives.
        """
        if not isinstance(chunk, AIMessageChunk):
            return  # ToolMessage results are handled in write_state_chunk

        reasoning_text = get_reasoning_content(chunk)
        text = reasoning_text + chunk.text
        if self._escape_markdown:
            text = escape_markdown_text(text)
        self.write(text)

        if chunk.tool_call_chunks:
            # N.B. LangChain sometimes waits for the whole string to complete before yielding chunks.
            # That's why long "sql" tool calls take some time to show up and then the whole sql is
            # shown in a batch.
            if not self._is_tool_calling:
                # First fragment of a tool-call run: print the tool names and open the fence once.
                self.write("\n\n")
                for tool_call_chunk in chunk.tool_call_chunks:
                    self.write(f"[tool_call: '{tool_call_chunk['name']}']\n")
                self.write("```\n")  # Open code block
                self._is_tool_calling = True
            for tool_call_chunk in chunk.tool_call_chunks:
                if tool_call_chunk["args"] is not None:
                    self.write(tool_call_chunk["args"])
        else:
            self._close_tool_call_block()

    def write_state_chunk(self, state_chunk: dict[str, Any]) -> None:
        """Render the messages of a state snapshot that have not been rendered yet.

        The message list is read from the ``"messages"`` key of ``state_chunk`` (a missing
        key is treated as an empty list). Tool messages are written as fenced output
        followed by any ``pd.DataFrame`` values found in a dict artifact; AI messages get
        their SQL tool calls pretty-printed when ``pretty_sql`` is enabled.
        """
        self._close_tool_call_block()

        # Loop through new messages only.
        # We could force the caller of the frontend to provide new messages only, but for
        # ease of use we assume the state contains the full list of messages and slice it here.
        messages: list[BaseMessage] = state_chunk.get("messages", [])
        new_messages = messages[self._message_count :]
        self._message_count += len(new_messages)

        for message in new_messages:
            if isinstance(message, ToolMessage):
                tool_call = get_tool_call(messages, message)
                tool_name = tool_call["name"] if tool_call is not None else "unknown"
                self.write(f"\n[tool_call_output: '{tool_name}']")
                self.write(f"\n```\n{message.text.strip()}\n```\n\n")
                if message.artifact is not None and isinstance(message.artifact, dict):
                    for art_name, art_value in message.artifact.items():
                        if isinstance(art_value, pd.DataFrame):
                            self.write_dataframe(art_value, name=art_name)
            elif self._pretty_sql and isinstance(message, AIMessage):
                # During tool calling we show raw JSON chunks, but for SQL we also want pretty formatting.
                for tool_call in message.tool_calls:
                    sql = get_tool_call_sql(tool_call)
                    if sql is not None:
                        self.write(f"\n```sql\n{sql.strip()}\n```\n\n")

    def write_stream_chunk(self, mode: str, chunk: Any) -> None:
        """Dispatch a stream item to the renderer matching its stream mode.

        Args:
            mode: ``"messages"`` (``chunk`` is unpacked as a ``(token_chunk, metadata)`` pair)
                or ``"values"`` (``chunk`` is a state dict). Other modes are ignored.
            chunk: The streamed payload for ``mode``.

        Raises:
            ValueError: If ``mode`` is ``"values"`` and ``chunk`` is not a dict.
        """
        if mode == "messages":
            token_chunk, _token_metadata = chunk
            self.write_message_chunk(token_chunk)
        elif mode == "values":
            if isinstance(chunk, dict):
                self.write_state_chunk(chunk)
            else:
                raise ValueError(f"Unexpected chunk type: {type(chunk)}")

    def start(self) -> None:
        """Mark the stream as started and write the opening header if headers are enabled."""
        # The flag must be set before writing: write() calls start() while it is unset.
        self._started = True
        if self._show_headers:
            self.write("=" * 8 + " <THINKING> " + "=" * 8 + "\n\n")

    def end(self) -> None:
        """Write the closing footer (if headers are enabled) and reset the started flag.

        A tool-call code fence that is still open (e.g. the stream stopped in the middle of
        a tool call) is closed first so the footer is not rendered inside the code block.
        """
        self._close_tool_call_block()
        if self._show_headers:
            self.write("\n" + "=" * 8 + " </THINKING> " + "=" * 8 + "\n\n")
        self._started = False

    def _close_tool_call_block(self) -> None:
        """Close the code fence opened for streamed tool-call arguments, if one is open."""
        if self._is_tool_calling:
            self.write("\n```\n\n")  # Close code block
            self._is_tool_calling = False


def escape_currency_dollar_signs(text: str) -> str:
    """Escape dollar signs directly followed by digits to prevent MathJax interpretation in markdown."""
    return re.sub(r"\$(\d+)", r"\$\1", text)


def escape_strikethrough(text: str) -> str:
    """Escape a tilde followed by digits (optionally after one other character).

    Prevents aggressive markdown strikethrough formatting, e.g. for "~100" or "~$5".
    """
    return re.sub(r"~(.?\d+)", r"\~\1", text)


def escape_markdown_text(text: str) -> str:
    """Apply the strikethrough escape and then the currency dollar-sign escape to ``text``."""
    text = escape_strikethrough(text)
    text = escape_currency_dollar_signs(text)
    return text


def dataframe_to_markdown(df: pd.DataFrame, *, index: bool = False) -> str:
    """Render ``df`` as a markdown table, falling back to ``df.to_string`` on any failure.

    Args:
        df: Dataframe to render.
        index: Whether to include the index in the output.
    """
    try:
        # to_markdown doesn't work with all types: https://github.com/pandas-dev/pandas/issues/50866
        return df.to_markdown(index=index)
    except Exception:
        # Deliberate best-effort: any rendering failure degrades to the plain-text table.
        return df.to_string(index=index)