databao/executors/lighthouse/history_cleaning.py (54 lines of code) (raw):
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolCall, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately
def _truncate_no_df_block(messages: list[BaseMessage]) -> AIMessage:
    """Collapse a block of messages into a single AIMessage built from its last message.

    The block must end with an AI message. Only the content of that final message
    is kept; it is prefixed with a note stating how many messages were dropped.
    """
    final_message = messages[-1]
    assert final_message.type == "ai"

    dropped_count = len(messages) - 1
    header = (
        f"Message history was truncated. {dropped_count} messages were deleted.\n"
        "Here is an answer, which was shown to the user:\n"
    )
    return AIMessage(content=header + str(final_message.content))
def _truncate_block(dfs: dict[str, dict[str, str]], messages: list[BaseMessage]) -> AIMessage:
    """Collapse a block that ends with a tool-call result into one summary AIMessage.

    Args:
        dfs: Query records; each record is read for its "query_id", "sql" and "df" entries.
        messages: The block to collapse. The last message must be a ToolMessage and the
            one before it an AIMessage whose first tool call carries the arguments
            ("query_id", "result_description" and, optionally, "visualization_prompt").

    Returns:
        An AIMessage whose text holds the truncation note, the SQL, the dataframe,
        the result description and, when present, the visualization prompt.

    Raises:
        AssertionError: If the block does not end with AIMessage + ToolMessage, or if
            no record with results matches the query id of the tool call.
    """
    assert isinstance(messages[-1], ToolMessage)
    assert isinstance(messages[-2], AIMessage)

    tool_call: ToolCall = messages[-2].tool_calls[0]
    args = tool_call["args"]
    query_id = args.get("query_id")

    # Only records that actually received results (they have a "df" entry) are eligible.
    # Without this check a missing query id (None) would match a record that has no
    # "query_id" key yet, and building the summary would fail with KeyError on "df".
    df = next(
        (d for d in dfs.values() if "df" in d and d.get("query_id") == query_id),
        None,
    )
    assert df is not None, f"No query result found for query_id={query_id!r}"

    text = f"""Message history was truncated. {len(messages) - 1} messages were deleted.
This SQL was generated:
```sql
{df["sql"]}
```
Here is an answer, which was shown to the user:
Dataframe:
{df["df"]}
Text:
{args["result_description"]}
"""
    visualization_prompt = args.get("visualization_prompt")
    if visualization_prompt:
        text += f"\n\nVisualization prompt: {visualization_prompt}"
    return AIMessage(content=text)
def clean_tool_history(messages: list[BaseMessage], token_limit: int) -> list[BaseMessage]:
    """
    If message history exceeds token limit, truncates it.

    It removes all intermediate messages and changes a final AI message.
    The final message contains SQL, dataframe and text.
    Specific for AgentState and ExecuteSubmit graph.

    Args:
        messages: Full message history. When truncation is needed, the last message
            must be a HumanMessage.
        token_limit: Approximate token count below which the history is returned
            unchanged (as a shallow copy).

    Returns: messages ready to be sent to LLM.

    Raises:
        AssertionError: If truncation is needed and the history does not end with a
            HumanMessage.
    """
    if count_tokens_approximately(messages) < token_limit:
        return messages.copy()

    assert isinstance(messages[-1], HumanMessage)

    # call id -> {"sql", "df", "query_id"} collected from `run_sql_query` calls and results
    dfs: dict[str, dict[str, str]] = {}
    # Messages seen since the last flush; collapsed or copied into `result` as a unit
    buffer: list[BaseMessage] = []
    result: list[BaseMessage] = []

    for i, curr_message in enumerate(messages):
        buffer.append(curr_message)

        if isinstance(curr_message, AIMessage):
            if curr_message.tool_calls:
                # Fill `dfs` dict
                for tool_call in curr_message.tool_calls:
                    if tool_call["name"] == "run_sql_query":
                        call_id = str(tool_call["id"])
                        sql = tool_call["args"]["sql"]
                        dfs[call_id] = {"sql": sql}
            elif len(buffer) > 3:
                # Long thread with no submission at the end.
                result.append(_truncate_no_df_block(buffer))
                buffer = []
        elif isinstance(curr_message, ToolMessage):
            call_id = curr_message.tool_call_id
            artifact = curr_message.artifact
            if call_id in dfs and artifact is not None and "csv" in artifact:
                # Enrich `dfs` dict with calculation results
                dfs[call_id]["df"] = artifact.get("csv")
                dfs[call_id]["query_id"] = artifact.get("query_id")
            else:
                # Only an AIMessage has `tool_calls`. The previous message can be another
                # ToolMessage (several tool calls answered in a row), and at index 0
                # there is no previous message at all, so check before reading it.
                prev_message = messages[i - 1] if i > 0 else None
                if (
                    isinstance(prev_message, AIMessage)
                    and prev_message.tool_calls
                    and prev_message.tool_calls[0]["name"] == "submit_result"
                ):
                    result.append(_truncate_block(dfs, buffer))
                    buffer = []
        else:
            # For system and human messages
            result.extend(buffer)
            buffer = []

    # No-op when the history ends with a HumanMessage (the buffer was just flushed);
    # keeps trailing messages instead of dropping them if assertions are disabled.
    result.extend(buffer)
    return result