def sample_action()

in mtrl/agent/distral.py [0:0]


    def sample_action(self, multitask_obs: ObsType, modes: List[str]) -> np.ndarray:
        """Sample one action per environment during training.

        Each environment's observation is routed to the per-task agent that
        owns the corresponding task, and the resulting actions are
        concatenated into a single array.
        """
        obs = multitask_obs["env_obs"]  # batched environment observations
        env_index = multitask_obs["task_obs"]  # task index of each environment
        actions = [
            # Look up the agent responsible for this task and forward the
            # corresponding slot of the batched observation to it.
            self.task_agents[self.task_index_to_agent_index[index]].sample_action(
                multitask_obs={
                    "env_obs": obs[self.task_index_to_agent_index[index]],
                    "task_obs": torch.LongTensor(
                        [
                            [index],
                        ]
                    ),  # not used in the actor.
                },
                modes=modes,
            )
            for index in env_index.numpy()
        ]
        # Concatenate the per-environment actions into a single array.
        actions = np.concatenate(actions, axis=0)
        return actions
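
For context, below is a minimal, self-contained sketch of the per-task dispatch pattern this method implements. The DummyTaskAgent class, the tensor shapes, and the task-to-agent mapping are illustrative assumptions, not part of mtrl.

# Sketch of the per-task dispatch pattern used by sample_action.
# DummyTaskAgent, the shapes, and the mapping below are assumptions for
# illustration only; they are not part of mtrl.
from typing import Dict, List

import numpy as np
import torch


class DummyTaskAgent:
    """Stand-in for a per-task agent exposing sample_action."""

    def __init__(self, action_dim: int = 2):
        self.action_dim = action_dim

    def sample_action(self, multitask_obs: Dict, modes: List[str]) -> np.ndarray:
        # Return a single random action shaped (1, action_dim).
        return np.random.uniform(-1.0, 1.0, size=(1, self.action_dim))


# Two task agents shared across three environments/tasks.
task_agents = [DummyTaskAgent(), DummyTaskAgent()]
task_index_to_agent_index = {0: 0, 1: 0, 2: 1}

multitask_obs = {
    "env_obs": torch.rand(3, 9),              # one observation per environment
    "task_obs": torch.LongTensor([0, 1, 2]),  # task index of each environment
}

actions = [
    task_agents[task_index_to_agent_index[index]].sample_action(
        multitask_obs={
            "env_obs": multitask_obs["env_obs"][task_index_to_agent_index[index]],
            "task_obs": torch.LongTensor([[index]]),
        },
        modes=["train"],
    )
    for index in multitask_obs["task_obs"].numpy()
]
print(np.concatenate(actions, axis=0).shape)  # -> (3, 2): one action per environment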