# _compute_da_loss()
#
# from curiosity/baseline_models.py [0:0]

    def _compute_da_loss(self,
                         output_dict,
                         messages: torch.Tensor,
                         shifted_context: torch.Tensor,
                         utter_mask: torch.Tensor,
                         dialog_acts: torch.Tensor,
                         dialog_acts_mask: torch.Tensor):
        """
        Given utterance at turn t, get the context (utter + acts) from t-1,
        the utter_t, and predict the act
        """
        message_w_context = torch.cat((
            messages, shifted_context
        ), dim=-1)

        # (batch_size, n_turns, n_dialog_acts)
        da_logits = self._da_classifier(message_w_context)
        output_dict['da_logits'] = da_logits
        da_unreduced_loss = self._da_bce_loss(da_logits, dialog_acts.float())
        # Note: the last dimension is expanded from singleton to n_dialog_acts
        # Since the mask is at turn level
        # (batch_size, n_turns, n_dialog_acts)
        da_combined_mask = (
            dialog_acts_mask.float().unsqueeze(-1)
            * utter_mask.float().unsqueeze(-1)
        ).expand_as(da_unreduced_loss)
        da_unreduced_loss = da_combined_mask * da_unreduced_loss
        # Mean loss over non-masked inputs, avoid division by zero
        da_loss = da_unreduced_loss.sum() / da_combined_mask.sum().clamp(min=1)
        da_loss_item = da_loss.item()
        self._da_loss_metric(da_loss_item)
        # (batch_size, n_turns, n_dialog_acts)
        da_preds = (torch.sigmoid(da_logits) > .5).long()
        self._da_f1_metric(da_preds, dialog_acts, da_combined_mask.long())
        return da_loss