in mtrl/agent/distral.py [0:0]
def sample_action(self, multitask_obs: ObsType, modes: List[str]) -> np.ndarray:
    """Sample one action per environment during training.

    For every task index in ``multitask_obs["task_obs"]``, the call is
    routed to the task agent that owns that task (looked up through
    ``self.task_index_to_agent_index``), and the per-agent actions are
    stacked into a single batch along axis 0.
    """
    env_obs = multitask_obs["env_obs"]
    task_indices = multitask_obs["task_obs"]
    per_env_actions = []
    for task_index in task_indices.numpy():
        agent_index = self.task_index_to_agent_index[task_index]
        agent = self.task_agents[agent_index]
        # NOTE(review): env_obs is indexed by the *agent* index, not by the
        # env's position in the batch — confirm this is intended when the
        # task->agent mapping is not the identity.
        single_obs = {
            "env_obs": env_obs[agent_index],
            # Threaded through for interface compatibility; the actor does
            # not consume task_obs.
            "task_obs": torch.LongTensor([[task_index]]),
        }
        per_env_actions.append(
            agent.sample_action(multitask_obs=single_obs, modes=modes)
        )
    return np.concatenate(per_env_actions, axis=0)