databao/executors/lighthouse/executor.py (109 lines of code) (raw):
from pathlib import Path
from typing import Any
import duckdb
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from sqlalchemy import Connection, Engine
from databao.configs import LLMConfig
from databao.core import Cache, ExecutionResult, Opa
from databao.core.data_source import DBDataSource, DFDataSource, Sources
from databao.core.executor import OutputModalityHints
from databao.duckdb.utils import describe_duckdb_schema, get_db_path, register_sqlalchemy
from databao.executors.base import GraphExecutor
from databao.executors.lighthouse.graph import ExecuteSubmit
from databao.executors.lighthouse.history_cleaning import clean_tool_history
from databao.executors.lighthouse.utils import get_today_date_str, read_prompt_template
class LighthouseExecutor(GraphExecutor):
def __init__(self) -> None:
super().__init__()
self._prompt_template = read_prompt_template(Path("system_prompt.jinja"))
# Create a DuckDB connection for the agent
self._duckdb_connection = duckdb.connect(":memory:")
self._graph: ExecuteSubmit = ExecuteSubmit(self._duckdb_connection)
self._compiled_graph: CompiledStateGraph[Any] | None = None
def render_system_prompt(
self,
data_connection: Any,
sources: Sources,
) -> str:
"""Render system prompt with database schema."""
db_schema = describe_duckdb_schema(data_connection)
context = ""
for db_name, source in sources.dbs.items():
if source.context:
context += f"## Context for DB {db_name}\n\n{source.context}\n\n"
for df_name, source in sources.dfs.items():
if source.context:
context += (
f"## Context for DF {df_name} (fully qualified name 'temp.main.{df_name}')\n\n{source.context}\n\n"
)
for idx, add_ctx in enumerate(sources.additional_context, start=1):
context += f"## General information {idx}\n\n{add_ctx.strip()}\n\n"
context = context.strip()
prompt = self._prompt_template.render(
date=get_today_date_str(), db_schema=db_schema, context=context, tool_limit=self._graph_recursion_limit // 2
)
return prompt.strip()
def register_db(self, source: DBDataSource) -> None:
"""Register DB in the DuckDB connection."""
connection = source.db_connection
if isinstance(connection, Connection):
connection = connection.engine
if isinstance(connection, duckdb.DuckDBPyConnection):
path = get_db_path(connection)
if path is not None:
connection.close()
self._duckdb_connection.execute(f"ATTACH '{path}' AS {source.name} (READ_ONLY)")
else:
raise RuntimeError("Memory-based DuckDB is not supported.")
elif isinstance(connection, Engine):
register_sqlalchemy(self._duckdb_connection, connection, source.name)
else:
raise ValueError("Only DuckDB or SQLAlchemy connections are supported.")
def register_df(self, source: DFDataSource) -> None:
self._duckdb_connection.register(source.name, source.df)
def _get_compiled_graph(self, llm_config: LLMConfig) -> CompiledStateGraph[Any]:
"""Get compiled graph."""
compiled_graph = self._compiled_graph or self._graph.compile(llm_config)
self._compiled_graph = compiled_graph
return compiled_graph
def drop_last_opa_group(self, cache: Cache, n: int = 1) -> None:
"""Drop last n groups of operations from the message history."""
messages = cache.get("state", default={}).get("messages", [])
human_messages = [m for m in messages if isinstance(m, HumanMessage)]
if len(human_messages) < n:
raise ValueError(f"Cannot drop last {n} operations - only {len(human_messages)} operations found.")
c = 0
while c < n:
m = messages.pop()
if isinstance(m, HumanMessage):
c += 1
def execute(
self,
opas: list[Opa],
cache: Cache,
llm_config: LLMConfig,
sources: Sources,
*,
rows_limit: int = 100,
stream: bool = True,
) -> ExecutionResult:
compiled_graph = self._get_compiled_graph(llm_config)
messages: list[BaseMessage] = self._process_opas(opas, cache)
# Prepend system message if not present
all_messages_with_system = messages
if not all_messages_with_system or all_messages_with_system[0].type != "system":
all_messages_with_system = [
SystemMessage(self.render_system_prompt(self._duckdb_connection, sources)),
*all_messages_with_system,
]
cleaned_messages = clean_tool_history(all_messages_with_system, llm_config.max_tokens_before_cleaning)
init_state = self._graph.init_state(cleaned_messages, limit_max_rows=rows_limit)
invoke_config = RunnableConfig(recursion_limit=self._graph_recursion_limit)
last_state = self._invoke_graph_sync(compiled_graph, init_state, config=invoke_config, stream=stream)
execution_result = self._graph.get_result(last_state)
# Update message history (excluding system message which we add dynamically)
final_messages = last_state.get("messages", [])
if final_messages:
new_messages = final_messages[len(cleaned_messages) :]
all_messages = all_messages_with_system + new_messages
all_messages_without_system = [msg for msg in all_messages if msg.type != "system"]
if execution_result.meta.get("messages"):
execution_result.meta["messages"] = all_messages
self._update_message_history(cache, all_messages_without_system)
# Set modality hints
execution_result.meta[OutputModalityHints.META_KEY] = self._make_output_modality_hints(execution_result)
return execution_result