databao/executors/react_duckdb/executor.py (62 lines of code) (raw):

import logging
from typing import Any

import duckdb
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from sqlalchemy import Connection, Engine

from databao.configs.llm import LLMConfig
from databao.core import Cache, ExecutionResult, Opa
from databao.core.data_source import DBDataSource, DFDataSource, Sources
from databao.core.executor import OutputModalityHints
from databao.duckdb import register_sqlalchemy
from databao.duckdb.react_tools import AgentResponse, execute_duckdb_sql, make_react_duckdb_agent
from databao.duckdb.utils import get_db_path
from databao.executors.base import GraphExecutor

logger = logging.getLogger(__name__)


def _sql_string_literal(value: Any) -> str:
    """Return *value* rendered as a single-quoted SQL string literal.

    Embedded single quotes are doubled so that e.g. a file path containing an
    apostrophe cannot terminate the literal early.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _sql_identifier(name: str) -> str:
    """Return *name* rendered as a double-quoted SQL identifier.

    Embedded double quotes are doubled. Quoting lets names that contain spaces
    or collide with SQL keywords be used as-is.
    """
    return '"' + str(name).replace('"', '""') + '"'


class ReactDuckDBExecutor(GraphExecutor):
    """Graph executor that runs a ReAct agent against an in-memory DuckDB connection.

    Data sources are registered into the executor's own DuckDB connection
    (``register_db`` / ``register_df``); ``execute`` then runs the agent graph
    and executes the SQL it produced on that same connection.
    """

    def __init__(self) -> None:
        """Initialize agent with lazy graph compilation."""
        super().__init__()
        # All registered sources are attached to / registered in this connection.
        self._duckdb_connection = duckdb.connect(":memory:")
        # Optional pre-compiled graph; stays None unless something assigns it.
        self._compiled_graph: CompiledStateGraph[Any] | None = None

    def _create_graph(self, data_connection: Any, llm_config: LLMConfig) -> CompiledStateGraph[Any]:
        """Create and compile the ReAct DuckDB agent graph.

        Args:
            data_connection: Connection handed to the agent factory.
            llm_config: Configuration used to instantiate a new chat model.

        Returns:
            The compiled agent graph.
        """
        return make_react_duckdb_agent(data_connection, llm_config.new_chat_model())

    def register_db(self, source: DBDataSource) -> None:
        """Register a database source in the executor's DuckDB connection.

        A file-backed DuckDB connection is closed and its database file is
        attached under ``source.name``. A SQLAlchemy ``Connection`` is reduced
        to its ``Engine`` and registered through ``register_sqlalchemy``.

        Args:
            source: The database source; ``source.db_connection`` must be a
                DuckDB connection, a SQLAlchemy ``Connection`` or an ``Engine``.

        Raises:
            RuntimeError: If the DuckDB connection has no database path
                (memory-based database).
            ValueError: If the connection type is not supported.
        """
        connection = source.db_connection
        if isinstance(connection, Connection):
            connection = connection.engine

        if isinstance(connection, duckdb.DuckDBPyConnection):
            path = get_db_path(connection)
            if path is None:
                raise RuntimeError("Memory-based DuckDB is not supported.")
            # The source connection is closed before its file is attached here.
            connection.close()
            # Escape the path and quote the alias so apostrophes, spaces or
            # keyword-like names do not break (or alter) the statement.
            self._duckdb_connection.execute(
                f"ATTACH {_sql_string_literal(path)} AS {_sql_identifier(source.name)}"
            )
        elif isinstance(connection, Engine):
            register_sqlalchemy(self._duckdb_connection, connection, source.name)
        else:
            raise ValueError("Only DuckDB or SQLAlchemy connections are supported.")

    def register_df(self, source: DFDataSource) -> None:
        """Register ``source.df`` in the DuckDB connection under ``source.name``."""
        self._duckdb_connection.register(source.name, source.df)

    def execute(
        self,
        opas: list[Opa],
        cache: Cache,
        llm_config: LLMConfig,
        sources: Sources,
        *,
        rows_limit: int = 100,
        stream: bool = True,
    ) -> ExecutionResult:
        """Run the agent on the given opas and execute the SQL it generates.

        Args:
            opas: Operations to process into the agent's input messages.
            cache: Cache used when building messages and storing message history.
            llm_config: LLM configuration used when a graph has to be created.
            sources: Part of the executor interface; not read by this method.
            rows_limit: Row limit passed to the SQL execution helper.
            stream: Forwarded to the graph invocation helper.

        Returns:
            An ``ExecutionResult`` carrying the agent's explanation, the
            generated SQL, the resulting dataframe, and output modality hints
            in ``meta``.
        """
        # Use a pre-set compiled graph if there is one, otherwise build one for
        # this call.
        # NOTE(review): nothing in this class assigns `_compiled_graph`, so the
        # graph is rebuilt on every call. Confirm whether caching is intended
        # (and safe w.r.t. a changing `llm_config`) before storing it.
        compiled_graph = self._compiled_graph
        if compiled_graph is None:
            compiled_graph = self._create_graph(self._duckdb_connection, llm_config)

        # Process the opas and get messages
        messages = self._process_opas(opas, cache)

        # Execute the graph
        init_state = {"messages": messages}
        invoke_config = RunnableConfig(recursion_limit=self._graph_recursion_limit)
        last_state = self._invoke_graph_sync(compiled_graph, init_state, config=invoke_config, stream=stream)

        answer: AgentResponse = last_state["structured_response"]
        logger.info("Generated query: %s", answer.sql)
        df = execute_duckdb_sql(answer.sql, self._duckdb_connection, limit=rows_limit)

        # Update message history
        final_messages = last_state.get("messages", [])
        self._update_message_history(cache, final_messages)

        execution_result = ExecutionResult(text=answer.explanation, code=answer.sql, df=df, meta={})

        # Set modality hints
        execution_result.meta[OutputModalityHints.META_KEY] = self._make_output_modality_hints(execution_result)

        return execution_result