databao/executors/frontend/text_frontend.py (104 lines of code) (raw):
import re
from typing import Any, TextIO
import pandas as pd
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ToolMessage
from databao.executors.frontend.messages import get_reasoning_content, get_tool_call, get_tool_call_sql
class TextStreamFrontend:
    """Helper for streaming LangGraph LLM outputs to a text stream (stdout, stderr, a file, etc.).

    Two kinds of stream chunks are rendered (see ``write_stream_chunk``):

    * ``"messages"`` chunks: AI message text is written as it arrives, and tool-call argument
      chunks are wrapped in a fenced code block.
    * ``"values"`` chunks: the graph state; only messages not rendered before are written
      (tool outputs, DataFrame artifacts and, optionally, SQL extracted from AI tool calls).
    """

    def __init__(
        self,
        start_state: dict[str, Any],
        *,
        writer: TextIO | None = None,
        escape_markdown: bool = False,
        show_headers: bool = True,
        pretty_sql: bool = True,
    ):
        """Create a frontend.

        Args:
            start_state: State before streaming starts. Messages already present under its
                ``"messages"`` key are treated as seen and are not rendered later.
            writer: Destination stream; ``None`` means ``print``'s default (``sys.stdout``).
            escape_markdown: Pass streamed AI text through ``escape_markdown_text``.
            show_headers: Wrap the output in ``<THINKING>`` / ``</THINKING>`` banners.
            pretty_sql: Additionally render SQL extracted from AI tool calls as a ``sql`` block.
        """
        self._writer = writer  # Use io.Writer type in Python 3.14
        self._escape_markdown = escape_markdown
        self._show_headers = show_headers
        # Number of state messages already rendered (or skipped because they pre-existed).
        self._message_count = len(start_state.get("messages", []))
        self._started = False
        # True while a tool-call code block is open in the output.
        self._is_tool_calling = False
        self._pretty_sql = pretty_sql

    def write(self, text: str) -> None:
        """Write ``text`` (no trailing newline), emitting the start header first if needed."""
        if not self._started:
            self.start()
        print(text, end="", flush=True, file=self._writer)

    def write_dataframe(self, df: pd.DataFrame, *, name: str | None = None, max_rows: int = 10) -> None:
        """Write a one-line summary followed by the first ``max_rows`` rows of ``df`` as a table."""
        rows_to_show = min(max_rows, len(df))
        self.write(f"[df: name={name or ''}, showing {rows_to_show} / {len(df)} rows]\n")
        df_str = dataframe_to_markdown(df.head(rows_to_show))
        self.write(f"{df_str}\n\n")

    def _close_tool_call_block(self) -> None:
        """Close the tool-call code block if one is open; no-op otherwise."""
        if self._is_tool_calling:
            self.write("\n```\n\n")  # Close code block
            self._is_tool_calling = False

    def write_message_chunk(self, chunk: BaseMessageChunk) -> None:
        """Render one streamed message chunk; only ``AIMessageChunk`` instances are written.

        Reasoning content (as returned by ``get_reasoning_content``) and the chunk text are
        written first. If the chunk carries tool-call chunks, their raw argument strings are
        written inside a code block that stays open until a chunk without tool calls, a state
        chunk, or ``end()`` closes it.
        """
        if not isinstance(chunk, AIMessageChunk):
            return  # ToolMessage results are handled in write_state_chunk
        has_tool_calls = bool(chunk.tool_call_chunks)
        if not has_tool_calls:
            # Close any open tool-call block *before* writing, so plain text never lands
            # inside the code fence.
            self._close_tool_call_block()

        text = get_reasoning_content(chunk) + chunk.text
        if self._escape_markdown:
            text = escape_markdown_text(text)
        self.write(text)

        if not has_tool_calls:
            return
        # N.B. LangChain sometimes waits for the whole string to complete before yielding chunks
        # That's why long "sql" tool calls take some time to show up and then the whole sql is shown in a batch
        if not self._is_tool_calling:
            self.write("\n\n")
            for tool_call_chunk in chunk.tool_call_chunks:
                self.write(f"[tool_call: '{tool_call_chunk['name']}']\n")
            self.write("```\n")  # Open code block
            self._is_tool_calling = True
        for tool_call_chunk in chunk.tool_call_chunks:
            if tool_call_chunk["args"] is not None:
                self.write(tool_call_chunk["args"])

    def write_state_chunk(self, state_chunk: dict[str, Any]) -> None:
        """Render the messages added to the state since the previous call.

        The state chunk is assumed to contain a "messages" key (a missing key is treated as
        an empty list). ``ToolMessage``s are written with their text output and any DataFrame
        artifacts. When ``pretty_sql`` is enabled, SQL found in ``AIMessage`` tool calls is
        written as a ``sql`` code block.
        """
        self._close_tool_call_block()
        # Loop through new messages only.
        # We could force the caller of the frontend to provide new messages only, but for ease
        # of use we assume the state holds the full message list and slice off what we've seen.
        # NOTE(review): this relies on the message list only ever growing (append-only).
        messages: list[BaseMessage] = state_chunk.get("messages", [])
        new_messages = messages[self._message_count :]
        self._message_count += len(new_messages)
        for message in new_messages:
            if isinstance(message, ToolMessage):
                # Label the output with the tool name; presumably get_tool_call finds the AI
                # tool call this message answers (None -> "unknown").
                tool_call = get_tool_call(messages, message)
                tool_name = tool_call["name"] if tool_call is not None else "unknown"
                self.write(f"\n[tool_call_output: '{tool_name}']")
                self.write(f"\n```\n{message.text.strip()}\n```\n\n")
                if isinstance(message.artifact, dict):
                    for art_name, art_value in message.artifact.items():
                        if isinstance(art_value, pd.DataFrame):
                            self.write_dataframe(art_value, name=art_name)
            elif self._pretty_sql and isinstance(message, AIMessage):
                # During tool calling we show raw JSON chunks, but for SQL we also want pretty formatting.
                for tool_call in message.tool_calls:
                    sql = get_tool_call_sql(tool_call)
                    if sql is not None:
                        self.write(f"\n```sql\n{sql.strip()}\n```\n\n")

    def write_stream_chunk(self, mode: str, chunk: Any) -> None:
        """Dispatch a LangGraph stream item by its stream ``mode``.

        ``"messages"``: ``chunk`` is a ``(message_chunk, metadata)`` pair.
        ``"values"``: ``chunk`` must be a state dict. Other modes are ignored.

        Raises:
            ValueError: if a ``"values"`` chunk is not a dict.
        """
        if mode == "messages":
            token_chunk, _token_metadata = chunk
            self.write_message_chunk(token_chunk)
        elif mode == "values":
            if isinstance(chunk, dict):
                self.write_state_chunk(chunk)
            else:
                raise ValueError(f"Unexpected chunk type: {type(chunk)}")

    def start(self) -> None:
        """Mark output as started and write the opening header (if enabled)."""
        self._started = True
        if self._show_headers:
            self.write("=" * 8 + " <THINKING> " + "=" * 8 + "\n\n")

    def end(self) -> None:
        """Close any open tool-call block and write the closing header (if enabled).

        With headers enabled, calling this before anything was written still emits the
        opening header first, because ``write`` starts the output on demand.
        """
        # Without this, a stream that ends mid tool call leaves the code fence unterminated.
        self._close_tool_call_block()
        if self._show_headers:
            self.write("\n" + "=" * 8 + " </THINKING> " + "=" * 8 + "\n\n")
        self._started = False
def escape_currency_dollar_signs(text: str) -> str:
    r"""Escape dollar signs that precede a digit (``$100`` -> ``\$100``).

    This prevents MathJax interpretation in markdown environments. A dollar sign that is
    already escaped (``\$100``) is left untouched. Re-escaping it would yield ``\\$100``,
    which renders a literal backslash followed by an *unescaped* ``$``.
    """
    # (?<!\\): skip already-escaped dollars; (?=\d): only currency-like "$<digit>".
    return re.sub(r"(?<!\\)\$(?=\d)", r"\\$", text)
def escape_strikethrough(text: str) -> str:
    r"""Prevent aggressive markdown strikethrough formatting.

    Escape a ``~`` that is followed by digits, optionally with one character in between
    (``~5`` -> ``\~5``, ``~$5`` -> ``\~$5``). An already-escaped ``\~`` is left untouched,
    so it is not turned into ``\\~``, which would un-escape the tilde.
    """
    return re.sub(r"(?<!\\)~(.?\d+)", r"\\~\1", text)
def escape_markdown_text(text: str) -> str:
    """Apply the markdown escapes used for streamed LLM text.

    Strikethrough escaping runs first, then currency dollar-sign escaping.
    """
    escaped = text
    for escaper in (escape_strikethrough, escape_currency_dollar_signs):
        escaped = escaper(escaped)
    return escaped
def dataframe_to_markdown(df: pd.DataFrame, *, index: bool = False) -> str:
    """Render ``df`` as a markdown table, falling back to plain text.

    Args:
        df: The frame to render.
        index: Whether to include the row index.

    Returns:
        ``df.to_markdown(...)`` output, or ``df.to_string(...)`` if markdown rendering fails.
    """
    try:
        markdown = df.to_markdown(index=index)
    except Exception:
        # Deliberate best-effort fallback: to_markdown doesn't work with all types
        # (https://github.com/pandas-dev/pandas/issues/50866) and needs the optional
        # `tabulate` package.
        return df.to_string(index=index)
    return markdown