databao/core/agent.py (121 lines of code) (raw):
from pathlib import Path
from typing import TYPE_CHECKING
from duckdb import DuckDBPyConnection
from langchain_core.language_models.chat_models import BaseChatModel
from pandas import DataFrame
from sqlalchemy import Connection, Engine
from databao.core.data_source import DBDataSource, DFDataSource, Sources
from databao.core.thread import Thread
if TYPE_CHECKING:
from databao.configs.llm import LLMConfig
from databao.core.cache import Cache
from databao.core.executor import Executor
from databao.core.visualizer import Visualizer
class Agent:
    """Hold the data sources, shared context and collaborators used by threads.

    An agent owns the registered database connections and DataFrames together
    with their textual context, and keeps the chat model, executor, visualizer
    and cache that every thread spawned from it shares. Several threads can be
    spawned out of one agent.
    """

    def __init__(
        self,
        llm: "LLMConfig",
        data_executor: "Executor",
        visualizer: "Visualizer",
        cache: "Cache",
        *,
        name: str = "default_agent",
        rows_limit: int,
        stream_ask: bool = True,
        stream_plot: bool = False,
        lazy_threads: bool = False,
        auto_output_modality: bool = True,
    ):
        """Create an agent with no data sources registered yet.

        Args:
            llm: LLM configuration; ``new_chat_model()`` is called on it once here
                and the resulting model is exposed through ``llm``.
            data_executor: Executor that every added source is registered with.
            visualizer: Visualizer exposed through ``visualizer``.
            cache: Cache exposed through ``cache``.
            name: Name of the agent.
            rows_limit: Value forwarded as ``rows_limit`` to every new thread.
            stream_ask: Default ``stream_ask`` value for new threads.
            stream_plot: Default ``stream_plot`` value for new threads.
            lazy_threads: Default ``lazy`` value for new threads.
            auto_output_modality: Default ``auto_output_modality`` value for new threads.
        """
        self.__name = name
        self.__llm = llm.new_chat_model()
        self.__llm_config = llm
        self.__sources: Sources = Sources(dfs={}, dbs={}, additional_context=[])
        self.__executor = data_executor
        self.__visualizer = visualizer
        self.__cache = cache

        # Thread defaults; each can be overridden per call in `thread()`
        # (except rows_limit, which is always taken from the agent).
        self.__rows_limit = rows_limit
        self.__lazy_threads = lazy_threads
        self.__auto_output_modality = auto_output_modality
        self.__stream_ask = stream_ask
        self.__stream_plot = stream_plot

    def _parse_context_arg(self, context: str | Path | None) -> str | None:
        """Resolve a context argument to plain text.

        Args:
            context: ``None``, the context text itself, or a ``Path`` to a file
                holding the context.

        Returns:
            ``None`` when no context was given, the file content for a ``Path``,
            otherwise the string unchanged.
        """
        if context is None:
            return None
        if isinstance(context, Path):
            # Pin the encoding so the same context file decodes identically on
            # every platform instead of depending on the locale default.
            return context.read_text(encoding="utf-8")
        return context

    @staticmethod
    def _next_source_name(prefix: str, taken: dict[str, DBDataSource] | dict[str, DFDataSource]) -> str:
        """Return the first free auto-generated name ``<prefix><n>``.

        Numbering starts at ``len(taken) + 1`` (so the usual sequence is
        db1, db2, ...) and moves forward past names that are already in use,
        e.g. because the caller registered a source explicitly named "db2".

        Args:
            prefix: Name prefix such as "db" or "df".
            taken: Mapping whose keys are the names already registered.

        Returns:
            A name that is not a key of ``taken``.
        """
        index = len(taken) + 1
        while f"{prefix}{index}" in taken:
            index += 1
        return f"{prefix}{index}"

    def add_db(
        self,
        connection: DuckDBPyConnection | Engine | Connection,
        *,
        name: str | None = None,
        context: str | Path | None = None,
    ) -> None:
        """Register a database connection with this agent and its executor.

        Args:
            connection: The database connection to add. Can be a SQLAlchemy engine
                or connection, or a native DuckDB connection.
            name: Optional name for the connection. If not provided (or empty), a
                default name such as 'db1', 'db2', etc. is generated; names that
                are already registered are skipped.
            context: Optional context for the connection. Either the path to a
                file whose content is used as the context, or the context text
                itself.

        Raises:
            ValueError: If ``connection`` is not one of the supported types.
        """
        if not isinstance(connection, (DuckDBPyConnection, Engine, Connection)):
            raise ValueError("Connection must be a DuckDB connection or SQLAlchemy engine.")

        conn_name = name or self._next_source_name("db", self.__sources.dbs)
        context_text = self._parse_context_arg(context) or ""
        source = DBDataSource(name=conn_name, context=context_text, db_connection=connection)

        # Register with the executor first: if that raises, the agent must not
        # keep a source the executor never accepted.
        self.executor.register_db(source)
        self.__sources.dbs[conn_name] = source

    def add_df(self, df: DataFrame, *, name: str | None = None, context: str | Path | None = None) -> None:
        """Register a DataFrame with this agent and its executor.

        Args:
            df: DataFrame to expose to the executor.
            name: Optional name; defaults to df1/df2/..., skipping names that are
                already registered.
            context: Optional text or path to a file describing this dataset for the LLM.
        """
        df_name = name or self._next_source_name("df", self.__sources.dfs)
        context_text = self._parse_context_arg(context) or ""
        source = DFDataSource(name=df_name, context=context_text, df=df)

        # Same ordering as add_db: only remember the source once the executor
        # has accepted it.
        self.executor.register_df(source)
        self.__sources.dfs[df_name] = source

    def add_context(self, context: str | Path) -> None:
        """Add additional context to help models understand your data.

        Use this method to add general information that might not be associated
        with a specific data source. If the information is specific to a data
        source, use the `context` argument of `add_db` and `add_df`.

        Args:
            context: The string or the path to a file containing the additional context.

        Raises:
            ValueError: If ``context`` is ``None``.
        """
        text = self._parse_context_arg(context)
        if text is None:
            raise ValueError("Invalid context provided.")
        self.__sources.additional_context.append(text)

    def thread(
        self,
        *,
        stream_ask: bool | None = None,
        stream_plot: bool | None = None,
        lazy: bool | None = None,
        auto_output_modality: bool | None = None,
    ) -> Thread:
        """Start a new thread in this agent.

        Each argument left as ``None`` falls back to the default given when the
        agent was created.

        Args:
            stream_ask: Override for the agent's ``stream_ask`` default.
            stream_plot: Override for the agent's ``stream_plot`` default.
            lazy: Override for the agent's ``lazy_threads`` default.
            auto_output_modality: Override for the agent's ``auto_output_modality`` default.

        Returns:
            A new ``Thread`` bound to this agent.

        Raises:
            ValueError: If neither a database nor a DataFrame has been registered.
        """
        if not self.__sources.dbs and not self.__sources.dfs:
            raise ValueError("No databases or dataframes registered in this agent.")
        return Thread(
            self,
            rows_limit=self.__rows_limit,
            stream_ask=stream_ask if stream_ask is not None else self.__stream_ask,
            stream_plot=stream_plot if stream_plot is not None else self.__stream_plot,
            lazy=lazy if lazy is not None else self.__lazy_threads,
            auto_output_modality=auto_output_modality
            if auto_output_modality is not None
            else self.__auto_output_modality,
        )

    @property
    def sources(self) -> Sources:
        """The agent's live ``Sources`` object (not a copy)."""
        return self.__sources

    @property
    def dbs(self) -> dict[str, DBDataSource]:
        """Shallow copy of the registered database sources, keyed by name."""
        return dict(self.__sources.dbs)

    @property
    def dfs(self) -> dict[str, DFDataSource]:
        """Shallow copy of the registered DataFrame sources, keyed by name."""
        return dict(self.__sources.dfs)

    @property
    def name(self) -> str:
        """Name of this agent."""
        return self.__name

    @property
    def llm(self) -> BaseChatModel:
        """Chat model created from the LLM config when the agent was constructed."""
        return self.__llm

    @property
    def llm_config(self) -> "LLMConfig":
        """LLM configuration the agent was constructed with."""
        return self.__llm_config

    @property
    def executor(self) -> "Executor":
        """Executor that data sources are registered with."""
        return self.__executor

    @property
    def visualizer(self) -> "Visualizer":
        """Visualizer the agent was constructed with."""
        return self.__visualizer

    @property
    def cache(self) -> "Cache":
        """Cache the agent was constructed with."""
        return self.__cache

    @property
    def additional_context(self) -> list[str]:
        """General additional context not specific to any one data source."""
        return self.__sources.additional_context