databao/core/thread.py (152 lines of code) (raw):

import uuid from typing import TYPE_CHECKING, Any from pandas import DataFrame from typing_extensions import Self from databao.core.executor import ExecutionResult, OutputModalityHints from databao.core.opa import Opa if TYPE_CHECKING: from databao.core.agent import Agent from databao.core.visualizer import VisualisationResult class Thread: """A single conversational thread within an agent. - Maintains its own message history (isolated from other threads). - Materializes data and visualizations eagerly or lazily and caches results per thread. - Exposes helpers to get the latest dataframe/text/plot/code. """ def __init__( self, agent: "Agent", *, rows_limit: int = 1000, stream_ask: bool = True, stream_plot: bool = False, lazy: bool = False, auto_output_modality: bool = True, ): self._agent = agent self._default_rows_limit = rows_limit self._lazy_mode = lazy self._auto_output_modality = auto_output_modality """Automatically detect the appropriate modality to output based on the user's input. If False, you must manually call the appropriate ask/plot method. This allows .ask to be used for plotting, i.e. `ask("show a bar chart")` will result in a plot being generated. """ self._stream_ask: bool | None = None self._stream_plot: bool | None = None self._default_stream_ask: bool = stream_ask self._default_stream_plot: bool = stream_plot self._data_materialized_rows: int | None = None self._data_result: ExecutionResult | None = None self._visualization_result: VisualisationResult | None = None self._visualization_request: str | None = None self._opas_processed_count: int = 0 self._opas: list[list[Opa]] = [] """Opas are grouped. Each group is processed independently.""" self._meta: dict[str, Any] = {} # A unique cache scope so executors can store per-thread state (e.g., message history) self._cache_scope = f"{self._agent.name}/{uuid.uuid4()}" def _materialize_data(self, rows_limit: int | None) -> "ExecutionResult": """Materialize the latest data state by executing pending OPAs if needed.""" new_opas = self._opas[self._opas_processed_count :] if len(new_opas) > 0: rows_limit = rows_limit if rows_limit else self._default_rows_limit stream = self._stream_ask if self._stream_ask is not None else self._default_stream_ask for opa in new_opas: self._data_result = self._agent.executor.execute( opa, cache=self._agent.cache.scoped(self._cache_scope), llm_config=self._agent.llm_config, sources=self._agent.sources, rows_limit=rows_limit, stream=stream, ) self._meta.update(self._data_result.meta) self._opas_processed_count += len(new_opas) self._data_materialized_rows = rows_limit if self._data_result is None: raise RuntimeError("_data_result is None after materialization") return self._data_result def _materialize_visualization(self, request: str | None, rows_limit: int | None) -> "VisualisationResult": """Materialize latest visualization for the given request and current data.""" data = self._materialize_data(rows_limit) if self._visualization_result is None or request != self._visualization_request: # TODO Cache visualization results as in Executor.execute()? stream = self._stream_plot if self._stream_plot is not None else self._default_stream_plot self._visualization_result = self._agent.visualizer.visualize(request, data, stream=stream) self._visualization_request = request self._meta.update(self._visualization_result.meta) self._meta["plot_code"] = self._visualization_result.code # maybe worth to expand as a property later if self._visualization_result is None: raise RuntimeError("_visualization_result is None after materialization") return self._visualization_result def _materialize(self, rows_limit: int | None) -> None: data_result = self._materialize_data(rows_limit) if not self._auto_output_modality: return # The Executor can provide output modality hints hints = data_result.meta.get(OutputModalityHints.META_KEY, OutputModalityHints()) if not hints.should_visualize: return # Let the Visualizer recommend a plot based on the df if no prompt is provided (None) self.plot(hints.visualization_prompt) def text(self) -> str: """Return the latest textual answer from the executor/LLM.""" return self._materialize_data(self._data_materialized_rows).text def code(self) -> str | None: """Return the latest generated code.""" return self._materialize_data(self._data_materialized_rows).code def meta(self) -> dict[str, Any]: """Aggregated metadata from executor/visualizer for this thread.""" self._materialize_data(self._data_materialized_rows) return self._meta def df(self, *, rows_limit: int | None = None) -> DataFrame | None: """Return the latest dataframe, materializing data as needed. Args: rows_limit: Optional override for the number of rows to materialize in lazy mode. """ df = self._materialize_data(rows_limit if rows_limit else self._data_materialized_rows).df # Copy the dataframe to avoid state mutation from outside return df.copy() if df is not None else None def plot( self, request: str | None = None, *, rows_limit: int | None = None, stream: bool | None = None ) -> "VisualisationResult": """Generate or return the latest visualization for the current data. Args: request: Optional natural-language plotting request. rows_limit: Optional row limit for data materialization in lazy mode. """ self._stream_plot = stream return self._materialize_visualization(request, rows_limit if rows_limit else self._data_materialized_rows) def ask(self, query: str, *, rows_limit: int | None = None, stream: bool | None = None) -> Self: """Append a new user query to this thread. Returns self to allow chaining (e.g., thread.ask("...")). Setting rows_limit has no effect in lazy mode. """ # NB. A new Opa is created even if it's identical to the previous one. if self._opas_processed_count < len(self._opas): assert self._lazy_mode self._opas[-1].append(Opa(query=query)) else: # Add new Opa group self._opas.append([Opa(query=query)]) # Invalidate old results so they are not used by repr methods self._data_result = None self._visualization_result = None # If multiple .asks are chained, the last setting takes precedence. # Tracking the stream setting for each ask in a chain would not work with "opa-collocation". self._stream_ask = stream if not self._lazy_mode: self._materialize(rows_limit) return self def drop(self, n: int = 1) -> None: """Remove N last user queries from this thread along with the answer it produced.""" sum_, n_groups = 0, 0 for group in reversed(self._opas): sum_ += len(group) n_groups += 1 if sum_ >= n: break n_materialized_group = n_groups - (len(self._opas) - self._opas_processed_count) # We need to drop `n` individual opas, combined into `n_groups` groups, # `n_materialized_group` of which are materialized. if sum_ == n: # Full drop of groups self._opas = self._opas[:-n_groups] else: full_groups = n_groups - 1 if full_groups > 0: self._opas = self._opas[:-full_groups] self._opas[-1] = self._opas[-1][: -(sum_ - n)] self._agent.executor.drop_last_opa_group(self._agent.cache.scoped(self._cache_scope), n=n_materialized_group) self._opas_processed_count -= n_materialized_group if self._opas: print( f"Dropped last {n} operation{'s' if n > 1 else ''}. Last remaining operation:" f"\n{self._opas[-1][-1].query}" ) else: print("Dropped all operations.") def __str__(self) -> str: if self._data_result is not None: bundle = self._data_result._repr_mimebundle_() if bundle is not None: if (text_markdown := bundle.get("text/markdown")) is not None: return text_markdown # type: ignore[no-any-return] elif (text_plain := bundle.get("text/plain")) is not None: return text_plain # type: ignore[no-any-return] return repr(self) def __repr__(self) -> str: if self._data_result is not None: return ( f"Materialized {self.__class__.__name__} with " f"{len(self._data_result.df) if self._data_result.df is not None else 0} data rows." ) else: return f"Unmaterialized {self.__class__.__name__}." def _repr_mimebundle_(self, include: Any = None, exclude: Any = None) -> dict[str, Any] | None: """Return MIME bundle for rendering in notebooks. No materialization is performed in this method. If using lazy mode, you must trigger materialization manually. """ # See docs for the behavior of magic methods https://ipython.readthedocs.io/en/stable/config/integrating.html#custom-methods # If None is returned, IPython will fall back to repr() if self._data_result is None: return None modality_hints = self._data_result.meta.get(OutputModalityHints.META_KEY, OutputModalityHints()) plot_bundle: dict[str, Any] | None = None if modality_hints.should_visualize and self._visualization_result is not None: plot_bundle = self._visualization_result._repr_mimebundle_(include, exclude) bundle = self._data_result._repr_mimebundle_(include, exclude, plot_mimebundle=plot_bundle) return bundle