# evals/elsuite/hr_ml_agent_bench/environment.py
"""
This file defines the `Environment` class, which manages the agent's workspace, including files,
datasets, and other resources.
Note: This file is adapted from MLAgentBench with minimal edits made. The original file can be
found at: https://github.com/snap-stanford/MLAgentBench/blob/main/MLAgentBench/environment.py.
"""
import copy
import fnmatch
import json
import os
import shutil
import signal
import time
from logging import getLogger
from multiprocessing import active_children
from pathlib import Path
from traceback import format_exception
from typing import Optional
from dacite import from_dict
from evals.elsuite.hr_ml_agent_bench.high_level_actions import HIGH_LEVEL_ACTIONS
from evals.elsuite.hr_ml_agent_bench.low_level_actions import LOW_LEVEL_ACTIONS
from evals.elsuite.hr_ml_agent_bench.prepare_task import get_research_problem, prepare_task
from evals.elsuite.hr_ml_agent_bench.schema import (
Action,
EnhancedJSONEncoder,
EnvException,
LLMError,
Step,
TooLongPromptError,
Trace,
)
from evals.solvers.solver import Solver
# Module-level logger, named after this module via `__name__`.
logger = getLogger(__name__)
class Environment:
    """Workspace, action dispatcher and trace recorder for a single benchmark task.

    On construction the environment creates the log directories, (re)builds the
    work directory from the benchmark's `env` folder (or from a saved snapshot
    when resuming), collects the list of read-only files and initialises the
    trace. It is also a context manager: leaving the `with` block terminates
    active child processes and writes timing/error files to the log directory.
    """

    def __init__(
        self,
        log_dir: Path,
        work_dir: Path,
        task: str,
        python_command: str,
        resume: bool,
        resume_step: int,
        device: int,
        max_steps: int,
        max_time: int,
        solver: Solver,
    ):
        """Set up directories, the workspace, the available actions and the trace.

        Args:
            log_dir: Directory receiving `trace.json`, `tool_logs/` and `traces/`.
            work_dir: Agent workspace. Any existing directory at this path is deleted.
            task: Name of the benchmark folder under `benchmarks/`.
            python_command: Forwarded to `prepare_task` and to tools as `python`.
            resume: Falsy to start fresh. NOTE: despite the `bool` annotation, a truthy
                value is joined with `"env_log"` and used as the path of a previous run.
            resume_step: Step index whose snapshot and trace prefix are restored on resume.
            device: Forwarded to tools as `device`.
            max_steps: `is_done` returns True once this many steps are recorded.
            max_time: `is_done` returns True once this many seconds have elapsed.
            solver: Forwarded to every tool call as `solver`.
        """
        self.log_dir = log_dir
        self.work_dir = work_dir
        self.python_command = python_command
        self.resume = resume
        self.resume_step = resume_step
        self.device = device
        self.max_steps = max_steps
        self.max_time = max_time
        self.solver = solver

        self._setup_log_dir()

        self._benchmark_folder_name = task
        self._research_problem = get_research_problem(task)
        self._read_only_files = []
        self._initialize_task_env()  # set up work dir and log dir

        self._action_infos = {t.name: t for t in LOW_LEVEL_ACTIONS + HIGH_LEVEL_ACTIONS}
        # Keyword arguments passed unchanged to every tool function in `execute`.
        # `read_only_files` is the same list object held by the environment.
        self._static_kwargs_for_tools = {
            "device": self.device,
            "python": self.python_command,
            "work_dir": self.work_dir,
            "read_only_files": self.read_only_files,
            "research_problem": self.research_problem,
        }
        self._trace = self._initialize_trace()
        self._start_time = time.time()

    ############################## getters ########################################

    @property
    def research_problem(self):
        """Research problem returned by `get_research_problem(task)`."""
        return self._research_problem

    @property
    def benchmark_folder_name(self):
        """Name of the benchmark folder (the `task` constructor argument)."""
        return self._benchmark_folder_name

    @property
    def read_only_files(self):
        """Workspace-relative paths that matched a pattern in `read_only_files.txt`."""
        return self._read_only_files

    @property
    def action_infos(self):
        """Mapping from action name to its action-info object."""
        return self._action_infos

    @property
    def static_kwargs_for_tools(self):
        """Keyword arguments passed to every tool function."""
        return self._static_kwargs_for_tools

    @property
    def trace(self):
        """Deep copy of the trace, so callers cannot mutate the internal one."""
        return copy.deepcopy(self._trace)

    @property
    def start_time(self):
        """`time.time()` value captured at the end of `__init__`."""
        return self._start_time

    ############################## internal functions ########################################

    def _setup_log_dir(self):
        """Create `log_dir` and its `tool_logs` and `traces` subdirectories if missing."""
        directories = (
            ("log_dir", self.log_dir),
            ("tools_log_dir", os.path.join(self.log_dir, "tool_logs")),
            ("traces_dir", os.path.join(self.log_dir, "traces")),
        )
        for label, directory in directories:
            if os.path.exists(directory):
                logger.info(f"{label} {directory} already exists")
            else:
                # exist_ok guards against another process creating it in between.
                os.makedirs(directory, exist_ok=True)

    def _initialize_task_env(self):
        """Rebuild the work directory and populate the read-only file list.

        Deletes any existing work directory, runs `prepare_task` for the benchmark,
        copies the benchmark's `env` folder into the work directory, records files
        matching `scripts/read_only_files.txt` and creates an empty `backup` folder.
        When resuming, the work directory is then replaced by the saved snapshot of
        step `resume_step`.
        """
        work_dir = self.work_dir

        # remove the workspace folder if it exists
        if os.path.exists(work_dir):
            shutil.rmtree(work_dir)

        benchmark_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "benchmarks",
            self.benchmark_folder_name,
        )

        # prepare if there is a prepare.py and it has not been prepared
        prepare_task(benchmark_dir, self.python_command)

        # copy the benchmarks folder to work_dir
        env_dir = os.path.join(benchmark_dir, "env")
        if os.path.exists(env_dir):
            shutil.copytree(env_dir, work_dir, symlinks=True)

        # find all read only files
        patterns_file = os.path.join(benchmark_dir, "scripts", "read_only_files.txt")
        if os.path.exists(patterns_file):
            with open(patterns_file, "r") as f:
                # Blank lines can never match a file name, so drop them up front.
                patterns = [line for line in f.read().split("\n") if line]

            for path, _subdirs, files in os.walk(work_dir):
                relpath = os.path.relpath(path, work_dir)
                # Paths are kept relative to work_dir (top-level files look like
                # "./name"); `save` builds paths the same way when filtering.
                filenames = [os.path.join(relpath, filename) for filename in files]
                for pattern in patterns:
                    self._read_only_files.extend(
                        name for name in filenames if fnmatch.fnmatch(name, pattern)
                    )

        # init backup folder and remove all content if it exists
        backup_dir = os.path.join(work_dir, "backup")
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)
        os.mkdir(backup_dir)

        if self.resume:
            shutil.rmtree(work_dir)
            resume_dir = os.path.join(
                self.resume,
                "env_log",
                "traces",
                f"step_{self.resume_step}_files",
            )
            logger.info(f"Restoring workspace from {resume_dir}")
            shutil.copytree(resume_dir, work_dir, symlinks=True)
            if not os.path.exists(backup_dir):
                os.mkdir(backup_dir)

    def _initialize_trace(self):
        """Return a fresh trace, or the saved trace truncated to `resume_step`.

        When resuming, steps up to and including `resume_step` are kept, together
        with the low-level steps whose timestamp is strictly earlier than the last
        kept step's timestamp.

        Raises:
            IndexError: If resuming and the truncated list of steps is empty.
        """
        if not self.resume:
            return Trace(
                steps=[],
                low_level_steps=[],
                action_infos=self.action_infos,
                task_description=self.research_problem,
            )

        logger.info(f"Restoring trace from {self.resume}")
        with open(os.path.join(self.resume, "env_log", "trace.json"), "r") as f:
            saved = json.load(f)
        prev_trace = from_dict(data_class=Trace, data=saved)

        logger.info(f"Resetting trace to step {self.resume_step}")
        steps = prev_trace.steps[: self.resume_step + 1]
        cutoff = steps[-1].timestamp
        low_level_steps = [s for s in prev_trace.low_level_steps if s.timestamp < cutoff]
        return Trace(
            steps=steps,
            low_level_steps=low_level_steps,
            action_infos=self.action_infos,
            task_description=self.research_problem,
        )

    def __enter__(self):
        """Return the environment itself for use in a `with` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Terminate child processes and write `error.txt` / `overall_time.txt`.

        Returns None, so an exception raised inside the `with` block propagates.
        """
        active = active_children()
        logger.info(f"Active Children: {len(active)}")
        # terminate all active children
        for child in active:
            child.terminate()
        # block until all children have closed
        for child in active:
            child.join()
        # report active children
        active = active_children()
        logger.info(f"Active Children: {len(active)}")

        # save error message
        if traceback is not None:
            logger.info("Error message saved in error.txt")
            with open(os.path.join(self.log_dir, "error.txt"), "w") as f:
                f.write("".join(format_exception(exc_type, exc_value, traceback)))

        with open(os.path.join(self.log_dir, "overall_time.txt"), "w") as f:
            f.write(str(time.time() - self.start_time))

    ################################# public functions ########################################

    def is_done(self):
        """Check if the task has reached a final state, either by reaching the maximum steps or time, or because the agent has submitted a final answer."""
        # Read the internal trace directly: the `trace` property deep-copies it.
        steps = self._trace.steps
        # check if any step is final answer
        any_final_answer = any(s.action.name == "Final Answer" for s in steps)
        return (
            len(steps) >= self.max_steps
            or any_final_answer
            or time.time() - self.start_time > self.max_time
        )

    def execute(self, action: Action, max_seconds_per_step: Optional[int] = None) -> str:
        """Execute an action and return the observation.

        The action and its observation are appended to the trace and a snapshot is
        saved, whichever branch produced the observation.

        Args:
            action: Action to run; `action.name` selects the tool and `action.args`
                (expected to be a dict) supplies its keyword arguments.
            max_seconds_per_step: If given, a SIGALRM alarm is armed for this many
                seconds around the tool call and expiry is reported as a
                `TimeoutError` observation.

        Returns:
            The observation string, which may describe an error instead of a result.

        Raises:
            Exception: If the tool raised an unexpected exception whose message
                contains "Connection aborted.".
        """
        trace = self._trace
        curr_step = len(trace.steps)
        action_name = action.name
        action_input = action.args

        if action_name == "Final Answer":
            observation = "end"
        elif self.is_done():
            observation = "The environment has shut down because the maximum number of steps or time has been reached. Please submit your final answer."
        elif action_name not in self.action_infos:
            actions = ", ".join(self.action_infos.keys())
            observation = f"Invalid action: {action_name}. Action did not execute. Please use one of the following actions:\n{actions}"
        else:
            # execute the action and get the observation
            log_file = os.path.join(
                os.path.join(self.log_dir, "tool_logs"),
                f"step_{curr_step}_tool_log.log",
            )

            usage = ",\n ".join(
                [f"{k}: [{v}]" for k, v in self.action_infos[action_name].usage.items()]
            )
            usage = f"""{{
{usage}
}}"""
            invalid_action_error = f"""No valid action found! Please ensure you're executing a valid action with json inputs. For example, to execute the `List Files` action, you would write:
Action: List Files
Action Input: {{
"dir_path": "."
}}
Likewise, the input for the action `{action_name}` needs to be valid json with proper entries. Please try again with the correct arguments:
Action: {action_name}
Action Input: {usage}"""

            if isinstance(action_input, dict):
                # Remembered so the process-wide SIGALRM handler can be put back.
                previous_handler = None
                try:
                    if max_seconds_per_step is not None:
                        previous_handler = signal.signal(signal.SIGALRM, _signal_handler)
                        signal.alarm(max_seconds_per_step)

                    observation = self.action_infos[action_name].function(
                        **action_input,
                        log_file=log_file,
                        trace=trace,
                        **self.static_kwargs_for_tools,
                        solver=self.solver,
                    )
                except TooLongPromptError:
                    observation = "EnvError: too long input for the tool"
                except LLMError as e:
                    observation = "LLMError: " + e.message
                except TimeoutError:
                    observation = f"TimeoutError: action execution time exceeded the maximum time limit of {max_seconds_per_step} seconds!"
                except EnvException as e:
                    observation = "EnvError: " + e.message
                except TypeError as e:
                    # Reported as malformed action input, e.g. missing or unexpected
                    # keyword arguments in `action_input`.
                    logger.info(f"Step: {curr_step}")
                    logger.info(e)
                    logger.info(action_input)
                    observation = "EnvError: " + invalid_action_error
                except Exception as e:
                    # should not happen
                    logger.info(f"Step: {curr_step}")
                    logger.info(e)
                    if "Connection aborted." in str(e):
                        raise Exception("Connection aborted for crfm") from e
                    observation = f"EnvError: Error executing {action_name}."
                finally:
                    if max_seconds_per_step is not None:
                        signal.alarm(0)  # disable the alarm
                        # `signal.signal` returns None when the old handler was not
                        # installed from Python; it cannot be re-installed then.
                        if previous_handler is not None:
                            signal.signal(signal.SIGALRM, previous_handler)
            else:
                observation = invalid_action_error

        step_time = time.time()
        trace.steps.append(Step(action, observation, step_time))
        self.save(curr_step)
        return observation

    def save(self, curr_step):
        """Save the trace and snapshot of the workspace folder"""
        # Serialise the internal trace directly; nothing here mutates it, so the
        # deep copy made by the `trace` property is unnecessary.
        with open(os.path.join(self.log_dir, "trace.json"), "w") as f:
            json.dump(self._trace, f, indent=4, cls=EnhancedJSONEncoder)

        ##### save a snapshot of the current step
        save_folder = os.path.join(self.log_dir, f"traces/step_{curr_step}_files")
        if os.path.exists(save_folder):
            shutil.rmtree(save_folder)
        os.makedirs(save_folder)

        # save files in the folder that are not read only
        read_only = set(self.read_only_files)  # O(1) membership test per file
        for path, _subdirs, files in os.walk(self.work_dir):
            relpath = os.path.relpath(path, self.work_dir)
            dest = os.path.join(save_folder, relpath)
            for file_name in files:
                file_path = os.path.join(relpath, file_name)
                if file_path in read_only:
                    continue
                # Destination directories are created lazily, so directories that
                # hold only read-only files are not created in the snapshot.
                os.makedirs(dest, exist_ok=True)
                shutil.copyfile(
                    os.path.join(self.work_dir, file_path),
                    os.path.join(save_folder, file_path),
                )

    ############## for logging convenience ##############

    def get_task_description(self):
        """Return the `(research_problem, benchmark_folder_name)` pair."""
        return self.research_problem, self.benchmark_folder_name

    @property
    def low_level_actions(self):
        """Action infos whose `is_primitive` attribute is truthy."""
        return [info for info in self.action_infos.values() if info.is_primitive]

    @property
    def high_level_actions(self):
        """Action infos whose `is_primitive` attribute is falsy."""
        return [info for info in self.action_infos.values() if not info.is_primitive]

    def print_action(self, entries):
        """Concatenate `"key: value"` for every item of `entries`, with no separator."""
        return "".join([k + ": " + v for k, v in entries.items()])
def _signal_handler(signum, frame):
raise TimeoutError("Time's up! The action exceeded the maximum time limit and terminated early")