evals/elsuite/hr_ml_agent_bench/autoeval.py
import json
import time
from dataclasses import dataclass, replace
from logging import getLogger
from pathlib import Path

from evals.elsuite.hr_ml_agent_bench.actions import get_action, is_valid_action
from evals.elsuite.hr_ml_agent_bench.auto_marking import EvaluationResult, grade_submission
from evals.elsuite.hr_ml_agent_bench.environment import Environment
from evals.elsuite.hr_ml_agent_bench.prompts import get_task_description
from evals.elsuite.hr_ml_agent_bench.schema import ActionInfo
from evals.solvers.solver import Solver
from evals.task_state import Message, TaskState

logger = getLogger(__name__)


@dataclass(frozen=True)
class Step:
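    """A single executed step: its index, the action taken, and the resulting observation."""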
step_idx: int
action: dict[str, str]
observation: str


@dataclass(frozen=True)
class TaskStateMetadata:
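    """Per-task metadata threaded through the agent loop via `TaskState.current_state`."""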
history_steps: tuple[Step, ...]
actions: dict[str, ActionInfo]
max_steps_in_context: int
max_retries: int
max_steps: int
log_dir: Path
env: Environment


@dataclass(frozen=True)
class FunctionCall:
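    """A parsed function call: the action name and its string arguments."""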
name: str
args: dict[str, str]


def run(
solver: Solver,
task_name: str,
research_problem: str,
log_dir: Path,
work_dir: Path,
max_steps: int,
max_time: int,
max_seconds_per_step: int,
device: int = 0,
python_command: str = "python",
resume: bool = False,
resume_step: int = 0,
max_steps_in_context: int = 3,
max_retries: int = 5,
) -> EvaluationResult:
"""Evaluates the solver on a given task."""
env = Environment(
log_dir=log_dir / "env_log",
work_dir=work_dir / task_name,
task=task_name,
python_command=python_command,
resume=resume,
resume_step=resume_step,
device=device,
max_steps=max_steps,
max_time=max_time,
solver=solver,
)
task_description = get_task_description(research_problem)
logger.info(task_description)
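    # Tell the solver its step and time budget up front so it can plan around them.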
messages = [
Message(
role="system",
content=f"You have a maximum of {max_steps} steps to solve the task. "
f"Each step is subject to a maximum time limit of {max_seconds_per_step} "
f"seconds. Additionally, your entire attempt is subject to a maximum "
f"time limit of {max_time} seconds.",
),
]
task_state = TaskState(
task_description=task_description,
messages=messages,
current_state=TaskStateMetadata(
history_steps=(),
actions=env.action_infos,
max_steps_in_context=max_steps_in_context,
max_retries=max_retries,
max_steps=max_steps,
log_dir=log_dir,
env=env,
),
)
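
    # Main agent loop: remind the solver of its remaining budget, ask it for an
    # action, execute the action, and record the observation in both the message
    # history and the structured step history.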
start_time = time.time()
for step in range(max_steps):
time_elapsed = time.time() - start_time
time_remaining = max_time - time_elapsed
task_state = replace(
task_state,
messages=task_state.messages
+ [
Message(
role="system",
content=f"You have {time_remaining:.2f} seconds and {max_steps - step} steps remaining.",
),
],
)
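        # Ask the solver for its next response given the conversation so far.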
result = solver(task_state)
action = get_action(result.output)
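        # If the response contains no parsable, known action, append a corrective
        # system message and move on; the step is still consumed.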
if action is None or not is_valid_action(action):
logger.info(
f"Step {step}: [invalid response; no action executed]\n\nAssistant:\n{result.output}"
)
new_messages = task_state.messages + [
Message(
role="assistant",
content=result.output,
),
Message(
role="system",
content="No valid action found! The list of valid actions was specified at the start; please pick an action from that list.",
),
]
task_state = TaskState(
task_description=task_state.task_description,
messages=new_messages,
current_state=task_state.current_state,
)
continue
logger.info(f"\n\nAction: {action.name}\nAction Input: {json.dumps(action.args)}\n")
new_max_seconds_per_step = min(max_seconds_per_step, time_remaining)
        if attempted_to_use_stable_baselines(result.output):
            observation = (
                "The action has been terminated because it used the `stable_baselines` "
                "library, which is not allowed. Please write your implementation from scratch."
            )
elif time_remaining > 0:
observation = env.execute(action, max_seconds_per_step=new_max_seconds_per_step)
        else:
            observation = (
                "Time's up! Your attempt has now exceeded the maximum time limit "
                f"of {max_time} seconds. The last action attempted was not executed, "
                "and your current solution will be graded."
            )
new_messages = task_state.messages + [
Message(
role="assistant",
content=result.output,
),
Message(
role="system",
content=f"Observation:\n\n```\n{observation}\n```",
),
]
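        # Record this step as a `Step` in the structured history carried in the
        # task state metadata.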
        new_history_steps = task_state.current_state.history_steps + (
            Step(
                step_idx=step,
                action={
                    "Action": action.name,
                    "Action Input": json.dumps(action.args, indent=4),
                },
                observation=observation,
            ),
        )
new_task_state_metadata = replace(
task_state.current_state,
history_steps=new_history_steps,
)
task_state = TaskState(
task_description=task_state.task_description,
messages=new_messages,
current_state=new_task_state_metadata,
)
logger.info(f"\n\nObservation:\n```\n{observation}\n```\n")
env.save(step)
if env.is_done():
break
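
    # Save the final environment state and grade whatever the solver submitted.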
env.save("final")
result = grade_submission(log_dir=log_dir, task_name=task_name)
return result


def attempted_to_use_stable_baselines(s: str) -> bool:
    """Returns True if `s` appears to mention the `stable_baselines` library."""
    s = s.lower()  # be case-insensitive
    return "stable" in s and "baseline" in s