project/paperbench/paperbench/agents/aisi-basic-agent/utils.py (374 lines of code) (raw):
import asyncio
import json
import logging
import subprocess
import time
from copy import deepcopy
from datetime import datetime
from typing import List, cast
import tiktoken
from inspect_ai._util.constants import HTTP
from inspect_ai._util.hooks import send_telemetry
from inspect_ai._util.interrupt import check_sample_interrupt
from inspect_ai._util.trace import trace_action
from inspect_ai._util.working import report_sample_waiting_time, sample_working_time
from inspect_ai.model._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
from inspect_ai.model._call_tools import disable_parallel_tools, tool_call_view, tools_info
from inspect_ai.model._chat_message import (
ChatMessage,
ChatMessageAssistant,
ChatMessageSystem,
ChatMessageUser,
)
from inspect_ai.model._generate_config import GenerateConfig, active_generate_config
from inspect_ai.model._model import (
active_model,
collapse_consecutive_assistant_messages,
collapse_consecutive_user_messages,
handle_sample_message_limit,
record_model_usage,
resolve_reasoning_history,
resolve_tool_model_input,
tool_result_images_as_user_message,
)
from inspect_ai.model._model_output import ChatCompletionChoice, ModelOutput
from inspect_ai.model._providers.openrouter import OpenRouterError
from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
from inspect_ai.tool._tool_def import ToolDef, tool_defs
from openai import LengthFinishReasonError
from openai.types.chat import ChatCompletion
from tenacity import (
RetryCallState,
retry,
retry_if_exception,
stop_after_attempt,
stop_after_delay,
stop_never,
wait_exponential_jitter,
)
from tenacity.stop import StopBaseT
logger = logging.getLogger(__name__)
def handle_message_len(
    msg: ChatMessage,
    tokenizer: tiktoken.Encoding,
    max_tokens: int,
) -> ChatMessage:
    """Truncate a message's text content so that it fits within `max_tokens` tokens."""
    def truncate_string(input_string: str, input_tokens: list[int], max_tokens: int) -> str:
n_tokens = len(input_tokens)
if n_tokens > max_tokens:
keep_tokens = max_tokens // 2
first_half = tokenizer.decode(input_tokens[:keep_tokens])
second_half = tokenizer.decode(input_tokens[-keep_tokens:])
return first_half + "\n...[content truncated due to length]...\n" + second_half
return input_string
if isinstance(msg.content, str):
item_tokens = tokenizer.encode(msg.content, disallowed_special=())
msg.content = truncate_string(msg.content, item_tokens, max_tokens)
elif isinstance(msg.content, list):
        token_lists: list[list[int]] = []  # tokenization of each content item
        token_counts: list[int] = []  # number of tokens in each content item
for item in msg.content:
if item.type == "text":
item_tokens = tokenizer.encode(item.text, disallowed_special=())
token_lists.append(item_tokens)
token_counts.append(len(item_tokens))
elif item.type == "reasoning":
item_tokens = tokenizer.encode(item.reasoning, disallowed_special=())
token_lists.append(item_tokens)
token_counts.append(len(item_tokens))
else:
# Non-text content types don't count towards token limit
token_lists.append([])
token_counts.append(0)
total_tokens = sum(token_counts)
if total_tokens == 0:
            return msg  # edge case: no text content to truncate
# Distribute max_tokens proportionally for text content
tokens_per_item = [
max(1, int((count / total_tokens) * max_tokens)) if count > 0 else 0
for count in token_counts
]
# Apply truncation while preserving content type
new_content = []
for item, item_tokens, max_tokens_for_item in zip(
msg.content, token_lists, tokens_per_item
):
if item.type == "text":
item.text = truncate_string(item.text, item_tokens, max_tokens_for_item)
elif item.type == "reasoning":
item.reasoning = truncate_string(item.reasoning, item_tokens, max_tokens_for_item)
new_content.append(item)
msg.content = new_content
return msg
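# Illustrative sketch (comments only, not executed): how handle_message_len
# truncates a long plain-string message. The encoding matches the o200k_base
# tokenizer used by prune_messages(); the message text below is hypothetical.
#
#   enc = tiktoken.get_encoding("o200k_base")
#   msg = ChatMessageUser(content="line of long tool output\n" * 100_000)
#   short = handle_message_len(msg, enc, max_tokens=1_000)
#   # short.content keeps roughly the first 500 and last 500 tokens, joined by
#   # "\n...[content truncated due to length]...\n"; messages already under the
#   # limit (and non-text content items) are returned unchanged.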
def get_gpu_generation() -> str | None:
"""Returns the GPU generation, if available."""
try:
result = subprocess.run(
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
except Exception:
return None
if result.returncode != 0:
return None
    gpu_names = [line.strip() for line in result.stdout.strip().split("\n") if line.strip()]
    if not gpu_names:
        return None
    return ", ".join(gpu_names)
def append_system_message(messages: list[ChatMessage], message: ChatMessageSystem) -> None:
    # find the index of the last existing system message (-1 if there is none)
    last_index = -1
    for i in reversed(range(len(messages))):
        if isinstance(messages[i], ChatMessageSystem):
            last_index = i
            break
    # insert the new system message immediately after it
    messages.insert(last_index + 1, message)
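# Illustrative sketch (comments only, not executed): the new system message is
# inserted directly after the last existing system message, so system prompts
# stay grouped at the head of the conversation. Messages below are hypothetical.
#
#   msgs: list[ChatMessage] = [ChatMessageSystem(content="base"), ChatMessageUser(content="hi")]
#   append_system_message(msgs, ChatMessageSystem(content="extra"))
#   # msgs order is now: "base" system, "extra" system, "hi" user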
def prune_messages(
messages: List[ChatMessage], prune_individual: bool = False
) -> List[ChatMessage]:
"""Prune messages to stay within API limits.
Removes older messages while preserving conversation coherence by:
- Keeping all system messages
- Keeping the initial task instruction message
- Removing the oldest 30% of conversation messages
- Ensuring tool messages remain paired with their parent assistant messages
Args:
messages: List of chat messages to prune
Returns:
List[ChatMessage]: Pruned list of messages with preserved conversation flow
"""
# Split messages into system and conversation parts
system_msgs: List[ChatMessage] = [m for m in messages if m.role == "system"]
conversation = [m for m in messages if m.role != "system"]
# Always preserve the first user message (task instructions)
task_msg = next((m for m in conversation if m.role == "user"), None)
# Remove oldest 30% of messages
start_idx = max(1, int(len(conversation) * 0.3)) # Keep 70%
preserved: List[ChatMessage] = [task_msg] if task_msg else []
preserved.extend(conversation[start_idx:])
conversation = preserved
# OAI API requires any messages with `msg.role == "tool"` to be preceded by
# an assistant message with corresponding tool_calls. Our pruning may violate this,
# so we need to clean up tool messages that lost their parent assistant message
valid_messages = []
active_tool_ids = set() # IDs from most recent assistant's tool_calls
for msg in conversation:
if "prompt is too long" in msg.content:
continue
if msg.role == "assistant":
# Track any tool calls from this assistant message
active_tool_ids = {tc.id for tc in (msg.tool_calls or [])}
valid_messages.append(msg)
elif msg.role == "tool" and getattr(msg, "tool_call_id", None) in active_tool_ids:
# Keep tool messages only if they match an active tool call
valid_messages.append(msg)
elif msg.role == "user":
# Reset tool tracking at user messages & keep the message
active_tool_ids = set()
valid_messages.append(msg)
if prune_individual:
# ensure individual messages are not over the context limit
MAX_TOKENS_PER_MESSAGE = 190000 # 200k token limit, minus 10k buffer
# use OAI's 200k tokenizer as an approximation
tokenizer = tiktoken.get_encoding("o200k_base")
valid_messages = [
handle_message_len(msg, tokenizer, MAX_TOKENS_PER_MESSAGE) for msg in valid_messages
]
# Reconstruct pruned conversation
return cast(List[ChatMessage], system_msgs + valid_messages)
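# Illustrative sketch (comments only, not executed): how an agent loop might
# call prune_messages after a context-length failure. The `state.messages`
# object is hypothetical.
#
#   state.messages = prune_messages(state.messages)
#   # drops the oldest ~30% of non-system messages while keeping system
#   # messages, the first user (task) message, and tool-call pairing intact
#
#   state.messages = prune_messages(state.messages, prune_individual=True)
#   # additionally truncates any single message above ~190k o200k_base tokens,
#   # for providers that reject an individual over-long message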
def log_rate_limit_retry(context: str, retry_state: RetryCallState) -> None:
logger.log(
HTTP,
f"{context} rate limit retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
)
async def generate_patched(
self,
input: str | list[ChatMessage],
tools: list[Tool] | list[ToolDef] | list[ToolInfo] | list[Tool | ToolDef | ToolInfo] = [],
tool_choice: ToolChoice | None = None,
config: GenerateConfig = GenerateConfig(),
cache: bool | CachePolicy = False,
) -> ModelOutput:
"""Generate output from the model.
Args:
input: Chat message input (if a `str` is passed it is converted
to a `ChatMessageUser`).
tools: Tools available for the model to call.
tool_choice: Directives to the model as to which tools to prefer.
config: Model configuration.
cache: Caching behavior for generate responses (defaults to no caching).
Returns:
ModelOutput
"""
# if we are the default model then enforce message limit if it
# exists (raise an exception if it is exceeded)
is_active_model = self == active_model()
if is_active_model:
handle_sample_message_limit(input)
# base config for this model
base_config = self.config
# if we are the active_model then merge active generate config
if is_active_model:
base_config = base_config.merge(active_generate_config())
# merge passed config
config = base_config.merge(config)
# provide max_tokens from the model api if required
if config.max_tokens is None:
config.max_tokens = self.api.max_tokens_for_config(config)
if config.max_tokens is None:
config.max_tokens = self.api.max_tokens()
# disable parallel tool calls if requested by any of our tools
if disable_parallel_tools(tools):
config.parallel_tool_calls = False
# normalize input to chat
if isinstance(input, str):
input = [ChatMessageUser(content=input)]
# insert any system message provided in config
if config.system_message:
input = [ChatMessageSystem(content=config.system_message)] + input
# enforce concurrency limits
start_time = datetime.now()
working_start = sample_working_time()
async with self._connection_concurrency(config):
from inspect_ai.log._samples import track_active_sample_retries
# generate
with track_active_sample_retries():
output = await _generate(
self=self,
input=input,
tools=tools,
tool_choice=tool_choice,
config=config,
cache=cache,
)
        # update the most recent ModelEvent with the actual start/completed times
        # and the computed working time. Events are created _after_ the call to
        # _generate (potentially in response to retries), so their timestamps are
        # updated here, where the full start/end times are known.
from inspect_ai.log._transcript import ModelEvent, transcript
last_model_event = transcript().find_last_event(ModelEvent)
if last_model_event:
last_model_event.timestamp = start_time
last_model_event.working_start = working_start
completed = datetime.now()
last_model_event.completed = completed
last_model_event.working_time = (
output.time if output.time is not None else (completed - start_time).total_seconds()
)
# return output
return output
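# Illustrative sketch (comments only, not executed): generate_patched mirrors the
# signature of inspect_ai's Model.generate, so the agent can presumably bind it
# over the active model at runtime, e.g.:
#
#   from inspect_ai.model._model import Model
#   Model.generate = generate_patched  # assumption: patched before the solver runs
#
# How the surrounding agent actually installs the patch is not shown in this file.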
async def _generate(
self,
input: list[ChatMessage],
tools: list[Tool] | list[ToolDef] | list[ToolInfo] | list[Tool | ToolDef | ToolInfo],
tool_choice: ToolChoice | None,
config: GenerateConfig,
cache: bool | CachePolicy = False,
) -> ModelOutput:
# default to 'auto' for tool_choice (same as underlying model apis)
tool_choice = tool_choice if tool_choice else "auto"
# extract tool defs if we can
tdefs = tool_defs([tool for tool in tools if not isinstance(tool, ToolInfo)])
# resolve all tools into tool_info
tools = tools_info(tools)
# if we have a specific tool selected then filter out the others
if isinstance(tool_choice, ToolFunction):
tools = [tool for tool in tools if tool.name == tool_choice.name]
# if tool_choice is "none" or if there are no tools then fully purge
# the tools (as some models (e.g. openai and mistral) get confused
# if you pass them tool definitions along with tool_choice == "none"
# (they both 'semi' use the tool by placing the arguments in JSON
# in their output!). on the other hand, anthropic actually errors if
# there are tools anywhere in the message stream and no tools defined.
if tool_choice == "none" or len(tools) == 0:
# allow model providers to implement a tools_required() method to
# force tools to be passed (we need this for anthropic)
if not self.api.tools_required():
tools = []
tool_choice = "none"
# handle reasoning history
input = resolve_reasoning_history(input, config, self.api)
# apply any tool model_input handlers
input = resolve_tool_model_input(tdefs, input)
# break tool image content out into user messages if the model doesn't
# support tools returning images
if not self.api.tool_result_images():
input = tool_result_images_as_user_message(input)
# optionally collapse *consecutive* messages into one -
# (some apis e.g. anthropic require this)
if self.api.collapse_user_messages():
input = collapse_consecutive_user_messages(input)
if self.api.collapse_assistant_messages():
input = collapse_consecutive_assistant_messages(input)
    # retry for transient http errors:
    # - no default timeout or max_retries (retry indefinitely unless configured)
    # - exponential backoff starting at 3 seconds, with each wait capped at
    #   2 minutes plus up to 3 seconds of jitter
if config.max_retries is not None and config.timeout is not None:
stop: StopBaseT = stop_after_attempt(config.max_retries) | stop_after_delay(config.timeout)
elif config.max_retries is not None:
stop = stop_after_attempt(config.max_retries)
elif config.timeout is not None:
stop = stop_after_delay(config.timeout)
else:
stop = stop_never
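    # note: combining with `|` yields tenacity's stop_any, so retrying halts as
    # soon as either the attempt limit or the overall delay limit is reached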
def before_sleep(retry_state: RetryCallState) -> None:
wait_time = retry_state.next_action.sleep
self.total_retry_time += wait_time
log_rate_limit_retry(self.api.model_name, retry_state)
@retry(
wait=wait_exponential_jitter(initial=3, max=(2 * 60), jitter=3),
retry=retry_if_exception(self.should_retry),
stop=stop,
before_sleep=before_sleep,
)
async def generate() -> ModelOutput:
check_sample_interrupt()
cache_entry: CacheEntry | None
if cache:
if isinstance(cache, CachePolicy):
policy = cache
else:
policy = CachePolicy()
cache_entry = CacheEntry(
base_url=self.api.base_url,
config=deepcopy(config),
input=input,
model=str(self),
policy=policy,
tool_choice=tool_choice,
tools=tools, # type: ignore
)
existing = cache_fetch(cache_entry)
if isinstance(existing, ModelOutput):
self._record_model_interaction(
input=input,
tools=tools,
tool_choice=tool_choice,
config=config,
cache="read",
output=existing,
call=None,
)
return existing
else:
cache_entry = None
# verify that model apis are allowed
self.verify_model_apis()
# record the interaction before the call to generate
# (we'll update it with the results once we have them)
complete = self._record_model_interaction(
input=input,
tools=tools,
tool_choice=tool_choice,
config=config,
cache="write" if cache else None,
)
with trace_action(logger, "Model", f"generate ({str(self)})"):
time_start = time.monotonic()
try:
# Apply timeout at the API call level if configured
if config.timeout is not None:
async with asyncio.timeout(config.timeout):
result = await self.api.generate(
input=input,
tools=tools,
tool_choice=tool_choice,
config=config,
)
else:
result = await self.api.generate(
input=input,
tools=tools,
tool_choice=tool_choice,
config=config,
)
except asyncio.TimeoutError:
logger.warning(f"API call timed out after {config.timeout} seconds")
# Create a ModelOutput that indicates the model timed out
message = ChatMessageAssistant(
content="Model exceeded time limit during completion",
source="generate",
)
output = ModelOutput(
model=str(self),
choices=[
ChatCompletionChoice(
message=message,
stop_reason="stop", # Using 'stop' since timeout isn't a valid stop_reason
)
],
time=config.timeout, # We know exactly how long it took - the timeout value
)
result = (output, None)
except OpenRouterError as e:
if (
"exceed context limit" in str(e)
or "context length" in str(e)
or "too long" in str(e)
):
error_completion = ChatCompletion(
choices=[], id="", created=0, model="", object="chat.completion"
)
error = LengthFinishReasonError(completion=error_completion)
# Only add special marker if "too long" is in the error message
if "too long" in str(e):
error.args = ("PRUNE_INDIVIDUAL_MESSAGES: Message is too long",)
raise error
else:
raise e
finally:
time_elapsed = time.monotonic() - time_start
if isinstance(result, tuple):
output, call = result
else:
output = result
call = None
# raise error
if isinstance(output, Exception):
complete(output, call)
# Wrap the error in a runtime error which will show the
# request which caused the error
error = repr(output)
request = json.dumps(call.request, indent=2) if call is not None else ""
error_message = f"{error}\n\nRequest:\n{request}"
raise RuntimeError(error_message)
# update output with time (call.time captures time spent
# on the actual request that succeeds w/ status 200)
if call and call.time is not None:
output.time = call.time
else:
output.time = time_elapsed
# add views to tool calls
for choice in output.choices:
for tool_call in choice.message.tool_calls or []:
tool_call.view = tool_call_view(tool_call, tdefs)
# complete the transcript event
complete(output, call)
# record usage
if output.usage:
            record_model_usage(f"{self}", output.usage)
            # send telemetry if it's hooked up
await send_telemetry(
"model_usage",
json.dumps(dict(model=str(self), usage=output.usage.model_dump())),
)
if cache and cache_entry:
cache_store(entry=cache_entry, output=output)
return output
    # call the model (this will do retries, etc., so report waiting time as the
    # total elapsed time minus the actual time of the successful model call)
time_start = time.monotonic()
model_output = await generate()
total_time = time.monotonic() - time_start
if model_output.time:
report_sample_waiting_time(total_time - model_output.time)
# return results
return model_output
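# Illustrative sketch (comments only, not executed): the "PRUNE_INDIVIDUAL_MESSAGES"
# marker set above is intended to be caught by the calling agent loop, which can
# then retry with per-message truncation. The loop below is hypothetical.
#
#   try:
#       output = await model.generate(state.messages, tools=tools)
#   except LengthFinishReasonError as e:
#       prune_individual = "PRUNE_INDIVIDUAL_MESSAGES" in str(e)
#       state.messages = prune_messages(state.messages, prune_individual=prune_individual)
#       # ...then retry generate with the pruned conversation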