in mtrl/experiment/experiment.py
def __init__(self, config: ConfigType, experiment_id: str = "0"):
    """Experiment Class to manage the lifecycle of a model.

    Args:
        config (ConfigType): config of the experiment.
        experiment_id (str, optional): id to identify the experiment. Defaults to "0".
    """
    self.id = experiment_id
    self.config = config
    self.device = torch.device(self.config.setup.device)
    self.get_env_metadata = get_env_metadata
    self.envs, self.env_metadata = self.build_envs()
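
    # env_metadata carries (among other things) the ordered task list, the
    # action/observation spaces and the maximum episode length. Map each task
    # name to its index so that tasks can later be referred to by position.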
key = "ordered_task_list"
if key in self.env_metadata and self.env_metadata[key]:
ordered_task_dict = {
task: index for index, task in enumerate(self.env_metadata[key])
}
else:
ordered_task_dict = {}
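
    # Environments can be excluded from training by name; the names listed in
    # the config are resolved to task indices via ordered_task_dict.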
key = "envs_to_exclude_during_training"
if key in self.config.experiment and self.config.experiment[key]:
self.envs_to_exclude_during_training = {
ordered_task_dict[task] for task in self.config.experiment[key]
}
print(
f"Excluding the following environments: {self.envs_to_exclude_during_training}"
)
else:
self.envs_to_exclude_during_training = set()
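
    # All tasks share a single action space and actions are expected to be
    # normalised to [-1, 1]; the asserts below enforce this.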
    self.action_space = self.env_metadata["action_space"]
    assert self.action_space.low.min() >= -1
    assert self.action_space.high.max() <= 1

    self.env_obs_space = self.env_metadata["env_obs_space"]
    env_obs_shape = self.env_obs_space.shape
    action_shape = self.action_space.shape
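
    # prepare_config fills in the parts of the config that depend on the env
    # metadata; the agent is then instantiated from its Hydra config node
    # (config.agent.builder).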
    self.config = prepare_config(config=self.config, env_metadata=self.env_metadata)
    self.agent = hydra.utils.instantiate(
        self.config.agent.builder,
        env_obs_shape=env_obs_shape,
        action_shape=action_shape,
        action_range=[
            float(self.action_space.low.min()),
            float(self.action_space.high.max()),
        ],
        device=self.device,
    )
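
    # Directories for saving evaluation videos, model checkpoints and the
    # replay buffer live under config.setup.save_dir.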
    self.video_dir = utils.make_dir(
        os.path.join(self.config.setup.save_dir, "video")
    )
    self.model_dir = utils.make_dir(
        os.path.join(self.config.setup.save_dir, "model")
    )
    self.buffer_dir = utils.make_dir(
        os.path.join(self.config.setup.save_dir, "buffer")
    )
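
    # The recorder only gets a directory when save_video is enabled; passing
    # None effectively disables video recording.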
    self.video = video.VideoRecorder(
        self.video_dir if self.config.experiment.save_video else None
    )
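
    # The replay buffer is also built via Hydra; task_obs_shape=(1,) reserves
    # room for a one-dimensional task identifier alongside each transition.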
    self.replay_buffer = hydra.utils.instantiate(
        self.config.replay_buffer,
        device=self.device,
        env_obs_shape=env_obs_shape,
        task_obs_shape=(1,),
        action_shape=action_shape,
    )
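
    # When resuming, restore the latest agent checkpoint and the saved replay
    # buffer and continue training from the recovered step.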
    self.start_step = 0

    should_resume_experiment = self.config.experiment.should_resume

    if should_resume_experiment:
        self.start_step = self.agent.load_latest_step(model_dir=self.model_dir)
        self.replay_buffer.load(save_dir=self.buffer_dir)
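
    # Logs from a previous run are retained when the experiment is resumed.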
    self.logger = Logger(
        self.config.setup.save_dir,
        config=self.config,
        retain_logs=should_resume_experiment,
    )

    self.max_episode_steps = self.env_metadata[
        "max_episode_steps"
    ]  # maximum number of steps the agent can take in one episode of the environment.

    self.startup_logs()
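

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file). It only
# illustrates that the constructor above expects a fully composed Hydra config
# with `setup`, `experiment`, `agent` and `replay_buffer` nodes; the config
# path/name and the call to `run()` are assumptions made for illustration,
# not the project's actual entry point.
# ---------------------------------------------------------------------------
import hydra
from mtrl.experiment.experiment import Experiment


@hydra.main(config_path="config", config_name="config")
def main(config) -> None:
    # Build the experiment from the composed config and start training.
    experiment = Experiment(config=config, experiment_id="0")
    experiment.run()


if __name__ == "__main__":
    main()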