# agents/utils/history_util.py
"""Message history with token tracking and prompt caching."""
from typing import Any
class MessageHistory:
"""Manages chat history with token tracking and context management."""
def __init__(
self,
model: str,
system: str,
context_window_tokens: int,
client: Any,
enable_caching: bool = True,
):
self.model = model
self.system = system
self.context_window_tokens = context_window_tokens
self.messages: list[dict[str, Any]] = []
self.total_tokens = 0
self.enable_caching = enable_caching
self.message_tokens: list[tuple[int, int]] = (
[]
) # List of (input_tokens, output_tokens) tuples
self.client = client
# set initial total tokens to system prompt
try:
system_token = (
self.client.messages.count_tokens(
model=self.model,
system=self.system,
messages=[{"role": "user", "content": "test"}],
).input_tokens
- 1
)
except Exception:
system_token = len(self.system) / 4
self.total_tokens = system_token
async def add_message(
self,
role: str,
content: str | list[dict[str, Any]],
usage: Any | None = None,
):
"""Add a message to the history and track token usage."""
if isinstance(content, str):
content = [{"type": "text", "text": content}]
message = {"role": role, "content": content}
self.messages.append(message)
if role == "assistant" and usage:
total_input = (
usage.input_tokens
+ getattr(usage, "cache_read_input_tokens", 0)
+ getattr(usage, "cache_creation_input_tokens", 0)
)
output_tokens = usage.output_tokens
current_turn_input = total_input - self.total_tokens
self.message_tokens.append((current_turn_input, output_tokens))
self.total_tokens += current_turn_input + output_tokens
def truncate(self) -> None:
"""Remove oldest messages when context window limit is exceeded."""
if self.total_tokens <= self.context_window_tokens:
return
TRUNCATION_NOTICE_TOKENS = 25
TRUNCATION_MESSAGE = {
"role": "user",
"content": [
{
"type": "text",
"text": "[Earlier history has been truncated.]",
}
],
}
def remove_message_pair():
self.messages.pop(0)
self.messages.pop(0)
if self.message_tokens:
input_tokens, output_tokens = self.message_tokens.pop(0)
self.total_tokens -= input_tokens + output_tokens
while (
self.message_tokens
and len(self.messages) >= 2
and self.total_tokens > self.context_window_tokens
):
remove_message_pair()
if self.messages and self.message_tokens:
original_input_tokens, original_output_tokens = (
self.message_tokens[0]
)
self.messages[0] = TRUNCATION_MESSAGE
self.message_tokens[0] = (
TRUNCATION_NOTICE_TOKENS,
original_output_tokens,
)
self.total_tokens += (
TRUNCATION_NOTICE_TOKENS - original_input_tokens
)
def format_for_api(self) -> list[dict[str, Any]]:
"""Format messages for Claude API with optional caching."""
result = [
{"role": m["role"], "content": m["content"]} for m in self.messages
]
if self.enable_caching and self.messages:
result[-1]["content"] = [
{**block, "cache_control": {"type": "ephemeral"}}
for block in self.messages[-1]["content"]
]
return result