databao/executors/lighthouse/history_cleaning.py (54 lines of code) (raw):

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolCall, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately

# Names of the tools whose calls drive the truncation logic below.
_RUN_SQL_TOOL = "run_sql_query"
_SUBMIT_TOOL = "submit_result"

# NOTE(review): the summary templates in the two `_truncate_*` helpers are
# reproduced verbatim from the reviewed source. Confirm whether line breaks are
# intended inside them (e.g. around the ```sql fence) -- TODO confirm.


def _is_submit_call(message: BaseMessage) -> bool:
    """Return True if *message* is an AI message whose first tool call is ``submit_result``."""
    if not isinstance(message, AIMessage) or not message.tool_calls:
        return False
    return message.tool_calls[0]["name"] == _SUBMIT_TOOL


def _truncate_no_df_block(messages: list[BaseMessage]) -> AIMessage:
    """Collapse a block that ended without a submission into one AI message.

    Args:
        messages: The buffered block; its last element must be an AI message.

    Returns:
        An AIMessage that states how many messages were dropped and then
        repeats the content of the last message of the block.
    """
    final_message = messages[-1]
    assert final_message.type == "ai"
    text = (
        f"Message history was truncated. {len(messages) - 1} messages were deleted. "
        "Here is an answer, which was shown to the user: "
    ) + str(final_message.content)
    return AIMessage(content=text)


def _truncate_block(dfs: dict[str, dict[str, str]], messages: list[BaseMessage]) -> AIMessage:
    """Collapse a block that ended with a submission into one summary AI message.

    Args:
        dfs: Per-``run_sql_query`` call records keyed by tool-call id. A record
            holds ``"sql"`` and, once its tool result was seen, ``"df"`` and
            ``"query_id"``.
        messages: The buffered block. The last element must be a ToolMessage
            and the one before it the AIMessage carrying the submit tool call.

    Returns:
        An AIMessage containing the SQL, the dataframe, the result description
        and, when present, the visualization prompt of the submitted answer.

    Raises:
        AssertionError: If the block does not end with AIMessage + ToolMessage,
            or if no executed query matches the submitted ``query_id``.
    """
    tool_message = messages[-1]
    ai_message = messages[-2]
    assert isinstance(tool_message, ToolMessage)
    assert isinstance(ai_message, AIMessage)

    tool_call: ToolCall = ai_message.tool_calls[0]
    args = tool_call["args"]
    wanted_query_id = args.get("query_id")

    # Only records that actually received a result ("df" present) are
    # candidates. Otherwise a query without a result and a submit call without
    # a query_id would match on `None == None` and fail later on `df["df"]`.
    df = next(
        (record for record in dfs.values() if "df" in record and record.get("query_id") == wanted_query_id),
        None,
    )
    assert df is not None, f"No executed query found for query_id={wanted_query_id!r}"

    text = (
        f"Message history was truncated. {len(messages) - 1} messages were deleted. "
        f"This SQL was generated: ```sql {df['sql']} ``` "
        "Here is an answer, which was shown to the user: "
        f"Dataframe: {df['df']} "
        f"Text: {args['result_description']} "
    )
    if args.get("visualization_prompt"):
        text += f"\n\nVisualization prompt: {args['visualization_prompt']}"
    return AIMessage(content=text)


def clean_tool_history(messages: list[BaseMessage], token_limit: int) -> list[BaseMessage]:
    """Truncate the message history when it exceeds the token limit.

    Intermediate messages of every finished block are removed and replaced by
    a single AI message. For a block that ended with a ``submit_result`` call
    that message contains the SQL, the dataframe and the text; for a long
    block that ended with a plain AI answer it contains that answer only.
    Specific for AgentState and ExecuteSubmit graph.

    Args:
        messages: Full message history; when truncation is needed the last
            message is expected to be a HumanMessage.
        token_limit: Histories whose approximate token count is below this
            value are returned as an unmodified shallow copy.

    Returns:
        Messages ready to be sent to the LLM.

    Raises:
        AssertionError: If truncation is needed and the last message is not a
            HumanMessage.
    """
    if count_tokens_approximately(messages) < token_limit:
        return messages.copy()
    assert isinstance(messages[-1], HumanMessage)

    dfs: dict[str, dict[str, str]] = {}
    buffer: list[BaseMessage] = []
    result: list[BaseMessage] = []

    for index, message in enumerate(messages):
        buffer.append(message)

        if isinstance(message, AIMessage):
            if message.tool_calls:
                # Remember the SQL of every query call; the result is attached
                # when the corresponding ToolMessage arrives.
                for tool_call in message.tool_calls:
                    if tool_call["name"] == _RUN_SQL_TOOL:
                        dfs[str(tool_call["id"])] = {"sql": tool_call["args"]["sql"]}
            elif len(buffer) > 3:
                # Long thread with no submission at the end.
                result.append(_truncate_no_df_block(buffer))
                buffer = []

        elif isinstance(message, ToolMessage):
            call_id = message.tool_call_id
            artifact = message.artifact
            if call_id in dfs and artifact is not None and "csv" in artifact:
                # Enrich the query record with the calculation results.
                dfs[call_id]["df"] = artifact.get("csv")
                dfs[call_id]["query_id"] = artifact.get("query_id")
            elif index > 0 and _is_submit_call(messages[index - 1]):
                # The guard avoids reading `.tool_calls` on a predecessor that
                # is not an AI message with tool calls (e.g. the second result
                # of parallel tool calls).
                result.append(_truncate_block(dfs, buffer))
                buffer = []

        else:
            # System and human messages close the current block untouched.
            result.extend(buffer)
            buffer = []

    # Empty whenever the history ends with a human message; kept so trailing
    # messages are not silently dropped if the assertion above is disabled.
    result.extend(buffer)
    return result