in project/paperbench/paperbench/agents/aisi-basic-agent/_basic_agent_plus.py [0:0]
def basic_agent_loop() -> Solver:
    """Build the main agent loop solver.

    Returns an inspect-style ``Solver`` whose ``solve`` repeatedly calls the
    model, executes tool calls, and handles submissions until the task state
    is marked completed (by message/token limits, the real-time limit, the
    attempt cap, or a successful submission).

    NOTE(review): this closure reads many names from the enclosing scope not
    visible here (``message_limit``, ``token_limit``, ``real_time_limit``,
    ``cache``, ``max_attempts``, ``max_tool_output``, ``score``,
    ``score_value_fn``, ``incorrect_message``, ``continue_message``,
    ``submission``, ``generate_patched``, ``prune_messages``,
    ``format_progress_time``) — confirm against the surrounding function.
    """
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        """Run the agent loop on ``state`` and return the updated state."""
        # resolve message_limit -- prefer parameter then fall back to task
        state.message_limit = message_limit or state.message_limit
        # resolve token limit (same precedence as message_limit)
        state.token_limit = token_limit or state.token_limit
        # track attempts (submissions) and loop iterations
        attempts = 0
        num_steps = 0
        start_time = time.time()
        model = get_model()
        # Monkey-patch the model: total_retry_time accumulates time spent in
        # provider retries (presumably updated by generate_patched — TODO
        # confirm) so it can be excluded from the "productive" time budget.
        setattr(model, "total_retry_time", 0)
        setattr(model, "generate", generate_patched)
        # main loop (state.completed checks message_limit and token_limit)
        while not state.completed:
            # Proactively prune messages if we're approaching 1000 messages for Claude 3.5 Sonnet on Anthropic API
            if "sonnet" in model.name.lower() and len(state.messages) > 900:
                state.messages = prune_messages(state.messages)
            # Finish depending on step and time limits. If both are set, both must be met.
            num_steps += 1
            # NOTE log.warning is necessary here for this to appear in inspect.log, which is necessary so we can know the agents runtime
            logger.warning(
                f"total runtime: {round(time.time() - start_time, 2)}, total productive runtime: {round(time.time() - start_time - model.total_retry_time, 2)}, total retry time: {round(model.total_retry_time, 2)}"
            )
            # Time limit is enforced on *productive* time: wall clock minus
            # time spent in provider retries.
            over_time_limit = (
                (time.time() - start_time - model.total_retry_time > real_time_limit)
                if real_time_limit is not None
                else False
            )
            if real_time_limit is not None and over_time_limit:
                state.completed = True
                break
            # Send progress update every 5 steps
            if num_steps % 5 == 0:
                if real_time_limit is not None:
                    elapsed_time = time.time() - start_time - model.total_retry_time
                    periodic_msg = f"Info: {format_progress_time(elapsed_time)} time elapsed out of {format_progress_time(real_time_limit)}. Remember, you only have to stop working when the time limit has been reached."
                else:
                    elapsed_time = time.time() - start_time
                    periodic_msg = f"Info: {format_progress_time(elapsed_time)} time elapsed"
                periodic_msg += "\n\nNote: Don't forget to git commit regularly!"
                state.messages.append(ChatMessageUser(content=periodic_msg))
            # Flags set when generation fails due to context-length overflow;
            # they select the pruning strategy below.
            length_finish_error = False
            prune_individual = False
            try:
                # Calculate timeout in seconds for model.generate if real_time_limit is set
                generate_timeout = None
                if real_time_limit is not None:
                    generate_timeout = int(
                        real_time_limit - (time.time() - start_time - model.total_retry_time)
                    )
                # generate output and append assistant message
                # (model.generate was replaced with the plain function
                # generate_patched above, so it is unbound — hence self=model)
                state.output = await model.generate(
                    self=model,
                    input=state.messages,
                    tools=state.tools,
                    cache=cache,
                    config=GenerateConfig(timeout=generate_timeout),
                )
                state.messages.append(state.output.message)
            except (LengthFinishReasonError, IndexError) as e:
                # Context overflow: fall through to the pruning branch below.
                length_finish_error = True
                # Sentinel in the error text requests per-message truncation
                # rather than dropping whole messages — TODO confirm against
                # prune_messages.
                if "PRUNE_INDIVIDUAL_MESSAGES" in str(e):
                    prune_individual = True
            except JSONDecodeError:
                # Tell the model its output was malformed and retry the step.
                state.messages.append(ChatMessageUser(content="The JSON returned was invalid."))
                continue
            # Handle context length overflow by pruning messages
            # (short-circuit: state.output is only read when no exception
            # occurred, so it is the fresh output from this iteration)
            if length_finish_error or state.output.stop_reason == "model_length":
                logger.warning("context length overflow")
                state.messages = prune_messages(
                    state.messages, prune_individual=prune_individual
                )
                continue
            # resolve tools calls (if any)
            if state.output.message.tool_calls:
                # For each tool call, use timeout equal to the time remaining on this task
                timeout = None
                if real_time_limit is not None:
                    timeout = int(
                        real_time_limit - (time.time() - start_time - model.total_retry_time)
                    )
                # call tool functions
                try:
                    async with asyncio.timeout(timeout):
                        tool_results = await call_tools(
                            state.output.message, state.tools, max_output=max_tool_output
                        )
                except asyncio.TimeoutError:
                    # A tool call exhausting the remaining budget ends the run.
                    state.messages.append(
                        ChatMessageUser(content="Timeout: The tool call timed out.")
                    )
                    state.completed = True
                    break
                state.messages.extend(tool_results)
                # was an answer submitted?
                answer = submission(tool_results)
                if answer:
                    # set the output to the answer for scoring
                    state.output.completion = answer
                    # exit if we are at max_attempts
                    attempts += 1
                    if attempts >= max_attempts:
                        state.completed = True
                        break
                    # exit if the submission is successful
                    answer_scores = await score(state)
                    if score_value_fn(answer_scores[0].value) == 1.0:
                        state.completed = True
                        break
                    # otherwise notify the model that it was incorrect and continue
                    else:
                        # incorrect_message may be a static string or a
                        # callable taking (state, scores)
                        response_message = (
                            incorrect_message(state, answer_scores)
                            if callable(incorrect_message)
                            else incorrect_message
                        )
                        state.messages.append(ChatMessageUser(content=response_message))
            # no tool calls, urge the model to continue
            else:
                state.messages.append(ChatMessageUser(content=continue_message))
        return state

    return solve