in mtrl/agent/components/moe_layer.py [0:0]
def forward(self, task_info: TaskInfo) -> TensorType:
    if self.should_use_task_encoding:
        # Use the task encoding directly as input to the gating trunk.
        emb = task_info.encoding
        if self.should_detach_task_encoding:
            # Stop gradients from the gate flowing back into the task encoder.
            emb = emb.detach()  # type: ignore[union-attr]
    else:
        # Otherwise, look up a learned embedding from the task/env index.
        env_index = task_info.env_index
        if len(env_index.shape) == 2:
            env_index = env_index.squeeze(1)
        emb = self.emb(env_index)
    output = self.trunk(emb)
    # Temperature-scaled softmax over experts; shape (batch, num_experts).
    gate = self._softmax(output / self.temperature)
    if not self.should_use_soft_attention:
        # Hard attention: keep only the top-k experts per task and
        # renormalize the surviving probabilities to sum to 1.
        topk_attention = gate.topk(self.topk, dim=1)
        topk_attention_indices = topk_attention[1]
        hard_attention_mask = torch.zeros_like(gate).scatter_(
            dim=1, index=topk_attention_indices, value=1.0
        )
        gate = gate * hard_attention_mask
        gate = gate / gate.sum(dim=1).unsqueeze(1)
    if len(gate.shape) > 2:
        raise ValueError(f"Expected a 2D gate tensor, got shape {gate.shape}")
    # Transpose to (num_experts, batch, 1) so the gate broadcasts over
    # per-expert outputs downstream.
    return gate.t().unsqueeze(2)
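
Below is a minimal, self-contained sketch (not part of moe_layer.py) of the hard top-k gating step in isolation; the toy gate tensor, batch size, and expert count are illustrative assumptions, not values from the original module.

import torch

# Toy softmax output over experts: shape (batch=2, num_experts=4).
# Values are assumed for illustration only.
gate = torch.tensor([[0.10, 0.40, 0.30, 0.20],
                     [0.05, 0.05, 0.50, 0.40]])
topk = 2

# Indices of the k largest probabilities in each row.
topk_indices = gate.topk(topk, dim=1)[1]

# Binary mask with ones at the selected top-k positions.
mask = torch.zeros_like(gate).scatter_(dim=1, index=topk_indices, value=1.0)

# Zero out non-selected experts, then renormalize each row to sum to 1.
sparse_gate = gate * mask
sparse_gate = sparse_gate / sparse_gate.sum(dim=1, keepdim=True)
print(sparse_gate)
# tensor([[0.0000, 0.5714, 0.4286, 0.0000],
#         [0.0000, 0.0000, 0.5556, 0.4444]])

# The same final shape transform as forward(): (num_experts, batch, 1).
out = sparse_gate.t().unsqueeze(2)
print(out.shape)  # torch.Size([4, 2, 1])

The renormalization keeps the sparse gate a valid probability distribution over the selected experts, so downstream expert outputs can be combined with a straightforward weighted sum.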