in lib/action_head.py [0:0]
def forward(self, input_data: torch.Tensor, mask=None) -> Any:
    """Return temperature-scaled log-probabilities over the last dimension.

    The input is optionally passed through ``self.linear_layer``, its last
    dimension is reshaped into ``self.output_shape``, the result is divided
    by ``self.temperature`` and finally normalised with ``log_softmax``
    along the last dimension.

    Args:
        input_data: Tensor whose last dimension is consumed by
            ``self.linear_layer`` (when one is configured) or is reshaped
            directly into ``self.output_shape``.
        mask: Optional boolean tensor used as an index into the reshaped
            logits. Entries where ``mask`` is False are overwritten with
            ``LOG0`` before normalisation. It must be a valid boolean index
            for the reshaped tensor (presumably the same shape, or matching
            its leading dimensions -- confirm against callers).

    Returns:
        A float32 tensor of log-probabilities with shape
        ``input.shape[:-1] + self.output_shape`` (where ``input`` is the
        output of the linear layer when present, else ``input_data``).

    Note:
        ``input_data`` is never modified. ``reshape`` may return a view of
        its input, so the scaling is done out-of-place; this also avoids the
        autograd error raised by in-place ops on a leaf that requires grad.
    """
    if self.linear_layer is not None:
        flat_out = self.linear_layer(input_data)
    else:
        flat_out = input_data

    shaped_out = flat_out.reshape(flat_out.shape[:-1] + self.output_shape)
    # Out-of-place division: allocates a fresh tensor so that neither this
    # step nor the masked write below can leak into the caller's input.
    shaped_out = shaped_out / self.temperature
    if mask is not None:
        # Safe to write in place here: shaped_out is now a private tensor.
        shaped_out[~mask] = LOG0
    # Convert to float32 to avoid RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half'
    return F.log_softmax(shaped_out.float(), dim=-1)