databao/visualizers/vega_chat.py (126 lines of code) (raw):
import io
import json
import logging
from typing import Any
import altair
import pandas as pd
from edaplot.image_utils import vl_to_png_bytes
from edaplot.llms import LLMConfig as VegaLLMConfig
from edaplot.vega import to_altair_chart
from edaplot.vega_chat.vega_chat import MessageInfo, VegaChatConfig, VegaChatGraph, VegaChatState
from langchain_core.runnables import RunnableConfig
from PIL import Image
from databao.configs.llm import LLMConfig
from databao.core import ExecutionResult, VisualisationResult, Visualizer
from databao.executors.base import GraphExecutor
from databao.visualizers.vega_vis_tool import VegaVisTool
# Module-level logger; used below to warn when a generated spec fails schema validation but is still drawable.
logger = logging.getLogger(__name__)
class VegaChatResult(VisualisationResult):
    """Visualisation result that also keeps the Vega-Lite spec and its dataframe.

    Keeping both around lets the chart be re-rendered on demand with a
    different backend (interactive wizard, Altair chart or static image).
    Every renderer returns ``None`` when either the spec or the dataframe is
    missing.
    """

    # Vega-Lite specification of the chart, or None when no spec is available.
    spec: dict[str, Any] | None = None
    # Dataframe that is rendered together with `spec`.
    spec_df: pd.DataFrame | None = None

    # TODO expose as part of the VisualisationResult API
    def interactive(self) -> VegaVisTool | None:
        """Build an interactive UI wizard for the Vega-Lite chart.

        The returned object can be rendered in interactive notebooks.
        Returns ``None`` if the spec or the dataframe is missing.
        """
        if self.spec is not None and self.spec_df is not None:
            return VegaVisTool(self.spec, self.spec_df)
        return None

    def altair(self) -> altair.Chart | None:
        """Build an interactive Altair chart from the spec and dataframe.

        The returned object can be rendered in interactive notebooks.
        Returns ``None`` if the spec or the dataframe is missing.
        """
        if self.spec is not None and self.spec_df is not None:
            return to_altair_chart(self.spec, self.spec_df)
        return None

    def image(self) -> Image.Image | None:
        """Render the chart to a static ``PIL.Image.Image``.

        Returns ``None`` if the spec or the dataframe is missing, or if the
        PNG conversion yields no bytes.
        """
        if self.spec is None or self.spec_df is None:
            return None
        png_bytes = vl_to_png_bytes(self.spec, self.spec_df)
        if png_bytes is None:
            return None
        return Image.open(io.BytesIO(png_bytes))
def _convert_llm_config(llm_config: LLMConfig) -> VegaLLMConfig:
    """Translate a databao ``LLMConfig`` into the equivalent edaplot LLM config.

    The two config classes are nearly identical, so each listed attribute is
    read from ``llm_config`` and forwarded under the same keyword name.
    """
    # Attributes copied one-to-one; the order matches the order they are read in.
    forwarded_fields = (
        "name",
        "temperature",
        "max_tokens",
        "reasoning_effort",
        "cache_system_prompt",
        "timeout",
        "api_base_url",
        "use_responses_api",
        "ollama_pull_model",
        "model_kwargs",
    )
    return VegaLLMConfig(**{field: getattr(llm_config, field) for field in forwarded_fields})
class VegaChatVisualizer(Visualizer):
    """Visualizer that produces Vega-Lite charts through an LLM-driven chat graph.

    A request plus a dataframe is handed to ``VegaChatGraph``; the last message
    of the resulting state is converted into a ``VegaChatResult``. Earlier
    results can be refined with ``edit`` because the full message history is
    stored in the result's ``meta``.
    """

    def __init__(self, llm_config: LLMConfig, *, return_interactive_chart: bool = False):
        """Store the chat configuration and the preferred plot backend.

        Args:
            llm_config: LLM settings; converted to the edaplot config type.
            return_interactive_chart: When True, results carry a ``VegaVisTool``
                plot instead of an Altair chart (or a PNG image for specs that
                fail schema validation).
        """
        self._return_interactive_chart = return_interactive_chart
        self._vega_config = VegaChatConfig(
            llm_config=_convert_llm_config(llm_config),
            # Column-name normalization deals with names that contain special characters.
            data_normalize_column_names=True,
        )

    def _process_result(self, state: VegaChatState, spec_df: pd.DataFrame) -> VegaChatResult:
        """Turn the final chat state into a ``VegaChatResult``.

        Args:
            state: Final graph state; its last message is the model output.
            spec_df: The (possibly transformed) dataframe tied to the generated spec.

        Returns:
            A result with a plot on success. When there is no spec, the output
            is not drawable, the chart is empty, or the PNG fallback yields no
            bytes, the result has ``plot=None`` and a failure text.
        """
        history = state["messages"]
        model_out = history[-1]
        text = model_out.message.text()
        spec = model_out.spec
        spec_json = None if spec is None else json.dumps(spec, indent=2)

        # Everything except `text` and `plot` is the same for success and failure results.
        shared = {
            "meta": {"messages": history},  # Full history. Also used for edit follow ups.
            "code": spec_json,
            "spec": spec,
            "spec_df": spec_df,
            "visualizer": self,
        }

        def failure() -> VegaChatResult:
            return VegaChatResult(text=f"Failed to visualize request! Output: {text}", plot=None, **shared)

        if spec is None or not model_out.is_drawable or model_out.is_empty_chart:
            return failure()

        # Equivalent to `not (not is_valid_schema and is_drawable)`.
        schema_ok = model_out.is_valid_schema or not model_out.is_drawable
        if schema_ok:
            if self._return_interactive_chart:
                plot = VegaVisTool(spec, spec_df)
            else:
                plot = to_altair_chart(spec, spec_df)
        else:
            # Vega-Lite specs can be invalid (so they cannot be used with altair), yet another backend
            # may still be able to draw them.
            logger.warning("Generated Vega-Lite spec is not valid, but it is still drawable: %s", spec_json)
            if self._return_interactive_chart:
                # The VegaVisTool backend uses vega-embed so it can handle corrupt specs
                plot = VegaVisTool(spec, spec_df)
            else:
                # Fall back to an Image, which can still be displayed in Jupyter notebooks.
                png_bytes = vl_to_png_bytes(spec, spec_df)
                if png_bytes is None:
                    return failure()
                plot = Image.open(io.BytesIO(png_bytes))

        return VegaChatResult(text=text, plot=plot, **shared)

    def _run_vega_chat(
        self, request: str, df: pd.DataFrame, *, messages: list[MessageInfo] | None = None, stream: bool = False
    ) -> VegaChatResult:
        """Run the chat graph synchronously for one request and package its output.

        Args:
            request: The user's visualization request.
            df: Dataframe to visualize.
            messages: Earlier message history to continue from, if any.
            stream: Forwarded to the graph executor.
        """
        chat_graph = VegaChatGraph(self._vega_config, df=df)
        initial_state = chat_graph.get_start_state(request, messages=messages)
        runnable = chat_graph.compile_graph(is_async=False)
        # Use an empty `config` instead of `None` due to a bug in the "AI Agents Debugger" PyCharm plugin.
        final_state: VegaChatState = GraphExecutor._invoke_graph_sync(
            runnable, initial_state, config=RunnableConfig(), stream=stream
        )
        # Read the dataframe from the graph only after the run has finished.
        return self._process_result(final_state, chat_graph.dataframe)

    def visualize(self, request: str | None, data: ExecutionResult, *, stream: bool = False) -> VegaChatResult:
        """Create a chart for ``data``.

        Returns a plot-less "Nothing to visualize" result when ``data`` has no
        dataframe. A generic prompt is used when ``request`` is None.
        """
        frame = data.df
        if frame is None:
            return VegaChatResult(text="Nothing to visualize", meta={}, plot=None, code=None, visualizer=self)
        # We could also call the ChartRecommender module, but since we want a
        # single output plot, a simple prompt is enough.
        prompt = (
            request
            if request is not None
            else "I don't know what the data is about. Show me an interesting plot."
        )
        return self._run_vega_chat(prompt, frame, stream=stream)

    def edit(self, request: str, visualization: VisualisationResult, *, stream: bool = False) -> VegaChatResult:
        """Apply a follow-up request to a previously produced visualization.

        Raises:
            ValueError: If ``visualization`` is not a ``VegaChatResult``, has no
                dataframe, or has no message history in its ``meta``.
        """
        if not isinstance(visualization, VegaChatResult):
            raise ValueError(f"{self.__class__.__name__} can only edit {VegaChatResult.__name__} objects")
        frame = visualization.spec_df
        if frame is None:
            raise ValueError("No dataframe found in the provided visualization")
        history = visualization.meta.get("messages")
        if history is None:
            raise ValueError("No message history found in the provided visualization")
        return self._run_vega_chat(request, frame, messages=history, stream=stream)