in curiosity/baseline_models.py [0:0]
def _compute_da_loss(self,
                     output_dict,
                     messages: torch.Tensor,
                     shifted_context: torch.Tensor,
                     utter_mask: torch.Tensor,
                     dialog_acts: torch.Tensor,
                     dialog_acts_mask: torch.Tensor):
    """
    Predict the dialog acts of every turn and return the masked BCE loss.

    The classifier input for turn t is the utterance at t concatenated (on
    the last dimension) with ``shifted_context``, which is meant to carry the
    context (utterance + acts) of turn t-1.

    Side effects: the raw logits are stored in ``output_dict['da_logits']``;
    the scalar loss value is passed to ``self._da_loss_metric`` and the
    thresholded predictions to ``self._da_f1_metric``.

    Returns the loss averaged over every (turn, act) cell whose turn is kept
    by both ``utter_mask`` and ``dialog_acts_mask``. If no cell is kept, the
    denominator is clamped to 1, so the returned loss is 0.
    """
    classifier_input = torch.cat([messages, shifted_context], dim=-1)
    # (batch_size, n_turns, n_dialog_acts)
    logits = self._da_classifier(classifier_input)
    output_dict['da_logits'] = logits

    per_cell_loss = self._da_bce_loss(logits, dialog_acts.float())

    # Both masks are turn-level: combine them, append a singleton axis and
    # broadcast it across the dialog-act dimension of the unreduced loss.
    turn_mask = dialog_acts_mask.float() * utter_mask.float()
    cell_mask = turn_mask.unsqueeze(-1).expand_as(per_cell_loss)

    # Clamp so an all-masked batch divides by 1 instead of 0.
    kept_cells = torch.clamp(cell_mask.sum(), min=1)
    loss = (per_cell_loss * cell_mask).sum() / kept_cells
    self._da_loss_metric(loss.item())

    # Multi-label prediction: an act is predicted when its probability > 0.5.
    predictions = torch.sigmoid(logits).gt(0.5).long()
    self._da_f1_metric(predictions, dialog_acts, cell_mask.long())
    return loss