def forward()

in mtrl/agent/components/moe_layer.py

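Computes the per-expert gating weights for a batch of tasks: a task embedding is passed through a trunk network, the resulting logits are softmaxed with a temperature, optionally sparsified to the top-k experts, and the weights are returned with shape (num_experts, batch, 1).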

    def forward(self, task_info: TaskInfo) -> TensorType:
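        # Build the task embedding: either use the (optionally detached)
        # precomputed task encoding, or look up a learned embedding by env index.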
        if self.should_use_task_encoding:
            emb = task_info.encoding
            if self.should_detach_task_encoding:
                emb = emb.detach()  # type: ignore[union-attr]
        else:
            env_index = task_info.env_index
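            # env_index may arrive as (batch, 1); flatten it to (batch,)
            # before the embedding lookup.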
            if len(env_index.shape) == 2:
                env_index = env_index.squeeze(1)
            emb = self.emb(env_index)
        output = self.trunk(emb)
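        # Temperature-scaled softmax over the trunk logits yields a soft
        # gate of shape (batch, num_experts).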
        gate = self._softmax(output / self.temperature)
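        # Hard attention: keep only the top-k experts per sample and renormalize,
        # e.g. with topk=2, [0.5, 0.3, 0.1, 0.1] -> [0.625, 0.375, 0.0, 0.0].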
        if not self.should_use_soft_attention:
            topk_attention = gate.topk(self.topk, dim=1)
            topk_attention_indices = topk_attention[1]
            hard_attention_mask = torch.zeros_like(gate).scatter_(
                dim=1, index=topk_attention_indices, value=1.0
            )
            gate = gate * hard_attention_mask
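            # Renormalize so the surviving top-k weights sum to 1 per sample.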
            gate = gate / gate.sum(dim=1).unsqueeze(1)
        # Sanity check: the gate must still be 2D (batch, num_experts) here.
        assert len(gate.shape) == 2, f"unexpected gate shape: {gate.shape}"
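        # Transpose to (num_experts, batch) and add a trailing singleton dim
        # so the gate broadcasts against per-expert outputs.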
        return gate.t().unsqueeze(2)
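
The returned gate has shape (num_experts, batch, 1), so it broadcasts directly against stacked per-expert outputs. Below is a minimal sketch of how such a gate can weight and mix expert outputs; the names and sizes (expert_outputs, out_dim, etc.) are illustrative assumptions, not part of mtrl:

    import torch

    num_experts, batch_size, out_dim = 4, 8, 32  # hypothetical sizes

    # Stand-in for per-expert outputs, shape (num_experts, batch, out_dim).
    expert_outputs = torch.randn(num_experts, batch_size, out_dim)

    # A gate as returned by forward() above: softmax over experts per sample,
    # transposed to (num_experts, batch) and given a trailing singleton dim.
    logits = torch.randn(batch_size, num_experts)
    gate = torch.softmax(logits, dim=1).t().unsqueeze(2)  # (num_experts, batch, 1)

    # The singleton dim broadcasts over out_dim, scaling each expert's output
    # by its per-sample weight before summing the mixture over experts.
    mixture = (gate * expert_outputs).sum(dim=0)  # (batch, out_dim)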