in mtrl/experiment/multitask.py [0:0]
def run(self):
"""Run the experiment."""
exp_config = self.config.experiment
vec_env = self.envs["train"]
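    # per-environment episode trackers, one entry per parallel sub-env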
episode_reward, episode_step, done = [
np.full(shape=vec_env.num_envs, fill_value=fill_value)
for fill_value in [0.0, 0, True]
    ]  # each array has shape (num_envs,)
if "success" in self.metrics_to_track:
success = np.full(shape=vec_env.num_envs, fill_value=0.0)
info = {}
assert self.start_step >= 0
episode = self.start_step // self.max_episode_steps
start_time = time.time()
multitask_obs = vec_env.reset() # (num_envs, 9, 84, 84)
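    # "task_obs" holds the task index of each parallel sub-env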
env_indices = multitask_obs["task_obs"]
train_mode = ["train" for _ in range(vec_env.num_envs)]
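    # main loop: per-episode logging, periodic evaluation/checkpointing,
    # action selection, agent updates, and one vectorized env step per iteration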
for step in range(self.start_step, exp_config.num_train_steps):
        if step % self.max_episode_steps == 0:  # episode boundary (all envs are synchronized)
if step > 0:
if "success" in self.metrics_to_track:
success = (success > 0).astype("float")
for index, _ in enumerate(env_indices):
self.logger.log(
f"train/success_env_index_{index}",
success[index],
step,
)
self.logger.log("train/success", success.mean(), step)
for index, env_index in enumerate(env_indices):
self.logger.log(
f"train/episode_reward_env_index_{index}",
episode_reward[index],
step,
)
self.logger.log(f"train/env_index_{index}", env_index, step)
self.logger.log("train/duration", time.time() - start_time, step)
start_time = time.time()
self.logger.dump(step)
# evaluate agent periodically
if step % exp_config.eval_freq == 0:
self.evaluate_vec_env_of_tasks(
vec_env=self.envs["eval"], step=step, episode=episode
)
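                # at the evaluation frequency, optionally checkpoint the agent and the replay buffer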
if exp_config.save.model:
self.agent.save(
self.model_dir,
step=step,
retain_last_n=exp_config.save.model.retain_last_n,
)
if exp_config.save.buffer.should_save:
self.replay_buffer.save(
self.buffer_dir,
size_per_chunk=exp_config.save.buffer.size_per_chunk,
num_samples_to_save=exp_config.save.buffer.num_samples_to_save,
)
episode += 1
episode_reward = np.full(shape=vec_env.num_envs, fill_value=0.0)
if "success" in self.metrics_to_track:
success = np.full(shape=vec_env.num_envs, fill_value=0.0)
self.logger.log("train/episode", episode, step)
if step < exp_config.init_steps:
action = np.asarray(
[self.action_space.sample() for _ in range(vec_env.num_envs)]
) # (num_envs, action_dim)
else:
with agent_utils.eval_mode(self.agent):
# multitask_obs = {"env_obs": obs, "task_obs": env_indices}
                action = self.agent.sample_action(
                    multitask_obs=multitask_obs,
                    modes=[train_mode],
                )  # (num_envs, action_dim)
# run training update
if step >= exp_config.init_steps:
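            # when the warm-up phase ends, catch up with a burst of init_steps updates;
            # after that, perform one gradient update per environment step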
num_updates = (
exp_config.init_steps if step == exp_config.init_steps else 1
)
for _ in range(num_updates):
self.agent.update(self.replay_buffer, self.logger, step)
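        # step all sub-envs in lockstep; reward, done and info are per-env collections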
next_multitask_obs, reward, done, info = vec_env.step(action)
if self.should_reset_env_manually:
if (episode_step[0] + 1) % self.max_episode_steps == 0:
                # +1 because episode_step starts at 0 and is incremented only after the buffer update below
next_multitask_obs = vec_env.reset()
episode_reward += reward
if "success" in self.metrics_to_track:
success += np.asarray([x["success"] for x in info])
        # allow infinite bootstrap: time-limit terminations are not treated as true episode ends
for index, env_index in enumerate(env_indices):
done_bool = (
0
if episode_step[index] + 1 == self.max_episode_steps
else float(done[index])
)
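            # store the transition unless this env is excluded from training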
if index not in self.envs_to_exclude_during_training:
self.replay_buffer.add(
multitask_obs["env_obs"][index],
action[index],
reward[index],
next_multitask_obs["env_obs"][index],
done_bool,
task_obs=env_index,
)
multitask_obs = next_multitask_obs
episode_step += 1
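    # training finished: delete replay buffer files saved under buffer_dir and close all envs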
self.replay_buffer.delete_from_filesystem(self.buffer_dir)
self.close_envs()