in mtrl/experiment/metaworld.py [0:0]
def collect_trajectory(self, vec_env: VecEnv, num_steps: int) -> None:
    multitask_obs = vec_env.reset()  # (num_envs, 9, 84, 84)
    env_indices = multitask_obs["task_obs"]
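    # env_indices: one task id per sub-environment; reused below to tag each
    # transition with its task when it is added to the replay buffer.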
    episode_reward, episode_step, done = [
        np.full(shape=vec_env.num_envs, fill_value=fill_value)
        for fill_value in [0.0, 0, True]
    ]  # (num_envs, 1)
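
    # Roll the vectorized envs forward for num_steps, storing every transition.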
    for _ in range(num_steps):
        with agent_utils.eval_mode(self.agent):
            action = self.agent.sample_action(
                multitask_obs=multitask_obs, mode="train"
            )  # (num_envs, action_dim)
        next_multitask_obs, reward, done, info = vec_env.step(action)
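        # The envs may never signal done on their own; when manual resets are
        # enabled, reset every env once the episode reaches max_episode_steps.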
        if self.should_reset_env_manually:
            if (episode_step[0] + 1) % self.max_episode_steps == 0:
                # +1 because episode_step counts from 0 and is incremented only
                # after the buffer update, so this fires on the episode's last step
                next_multitask_obs = vec_env.reset()
        episode_reward += reward

        # allow infinite bootstrap
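        # (a timeout is not a true terminal state, so such transitions are stored
        # with done=0 and the critic keeps bootstrapping from the next observation)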
        for index, env_index in enumerate(env_indices):
            done_bool = (
                0
                if episode_step[index] + 1 == self.max_episode_steps
                else float(done[index])
            )
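            # store the transition, tagged with its task via env_index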
            self.replay_buffer.add(
                multitask_obs["env_obs"][index],
                action[index],
                reward[index],
                next_multitask_obs["env_obs"][index],
                done_bool,
                env_index=env_index,
            )
        multitask_obs = next_multitask_obs
        episode_step += 1
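
# A minimal usage sketch (not part of this file): during the warm-up phase of a
# run, an experiment would call this to seed the replay buffer before any
# gradient updates. The names `train_env` and `init_steps` are assumptions,
# not confirmed repo API.
#
#     self.collect_trajectory(vec_env=train_env, num_steps=init_steps)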