databao/executors/base.py (88 lines of code) (raw):
from abc import ABC
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from databao.core import Cache
from databao.core.executor import ExecutionResult, Executor, OutputModalityHints
from databao.core.opa import Opa
from databao.executors.frontend.text_frontend import TextStreamFrontend
try:
from duckdb import DuckDBPyConnection
except ImportError:
DuckDBPyConnection = Any # type: ignore
class GraphExecutor(Executor, ABC):
"""
Base class for LangGraph executors that execute with a DuckDB connection and LLM configuration.
Provides common functionality for graph caching, message handling, and OPA processing.
"""
def __init__(self) -> None:
"""Initialize agent with graph caching infrastructure."""
self._graph_recursion_limit = 50
def _process_opas(self, opas: list[Opa], cache: Cache) -> list[Any]:
"""
Process a single opa and convert it to a message, appending to message history.
Returns:
All messages including the new one
"""
messages: list[Any] = cache.get("state", {}).get("messages", [])
query = "\n\n".join(opa.query for opa in opas)
messages.append(HumanMessage(content=query))
return messages
def _update_message_history(self, cache: Cache, final_messages: list[Any]) -> None:
"""Update message history in cache with final messages from graph execution."""
if final_messages:
cache.put("state", {"messages": final_messages})
def _make_output_modality_hints(self, result: ExecutionResult) -> OutputModalityHints:
# A separate LLM module could be used to fill out the hints
vis_prompt = result.meta.get("visualization_prompt", None)
if vis_prompt is not None and len(vis_prompt) == 0:
vis_prompt = None
df = result.df
should_visualize = vis_prompt is not None and df is not None and len(df) >= 3
return OutputModalityHints(visualization_prompt=vis_prompt, should_visualize=should_visualize)
@staticmethod
def _invoke_graph_sync(
compiled_graph: CompiledStateGraph[Any],
start_state: Any,
*,
config: RunnableConfig | None = None,
stream: bool = True,
**kwargs: Any,
) -> Any:
"""Invoke the graph with the given start state and return the output state."""
if stream:
return GraphExecutor._execute_stream_sync(compiled_graph, start_state, config=config, **kwargs)
else:
return compiled_graph.invoke(start_state, config=config)
@staticmethod
async def _execute_stream(
compiled_graph: CompiledStateGraph[Any],
start_state: Any,
*,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
writer = TextStreamFrontend(start_state)
last_state = None
async for mode, chunk in compiled_graph.astream(
start_state,
stream_mode=["values", "messages"],
config=config,
**kwargs,
):
writer.write_stream_chunk(mode, chunk)
if mode == "values":
last_state = chunk
writer.end()
assert last_state is not None
return last_state
@staticmethod
def _execute_stream_sync(
compiled_graph: CompiledStateGraph[Any],
start_state: Any,
*,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
writer = TextStreamFrontend(start_state)
last_state = None
for mode, chunk in compiled_graph.stream(
start_state,
stream_mode=["values", "messages"],
config=config,
**kwargs,
):
writer.write_stream_chunk(mode, chunk)
if mode == "values":
last_state = chunk
writer.end()
assert last_state is not None
return last_state