in mdr/qa/qa_model.py
import torch
import torch.nn.functional as F

def forward(self, batch):
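    # Encode the question + context pair; token_type_ids may be absent for
    # models (e.g. RoBERTa) whose tokenizers do not produce segment ids.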
    outputs = self.encoder(batch['input_ids'], batch['attention_mask'], batch.get('token_type_ids', None))

    if "electra" in self.model_name:
        sequence_output = outputs[0]
        pooled_output = self.pooler(sequence_output)
    else:
        sequence_output, pooled_output = outputs[0], outputs[1]
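
    # Project each token to a (start, end) logit pair, then mask positions
    # outside the paragraph so answer spans are restricted to passage tokens.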
    logits = self.qa_outputs(sequence_output)
    outs = [o.squeeze(-1) for o in logits.split(1, dim=-1)]
    outs = [o.float().masked_fill(batch["paragraph_mask"].ne(1), float("-inf")).type_as(o) for o in outs]
    start_logits, end_logits = outs[0], outs[1]
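
    # Passage relevance score from the pooled representation.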
    rank_score = self.rank(pooled_output)
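
    # Supporting-fact prediction: gather the hidden state at each sentence
    # marker offset and score each sentence independently.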
    if self.sp_pred:
        gather_index = batch["sent_offsets"].unsqueeze(2).expand(-1, -1, sequence_output.size()[-1])
        sent_marker_rep = torch.gather(sequence_output, 1, gather_index)
        sp_score = self.sp(sent_marker_rep).squeeze(2)
    else:
        sp_score = None
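
    # Training mode returns a single combined loss; inference returns scores.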
    if self.training:
        rank_target = batch["label"]

        if self.sp_pred:
            sp_loss = F.binary_cross_entropy_with_logits(sp_score, batch["sent_labels"].float(), reduction="none")
            # Mask the per-sentence loss rather than scaling it by the raw
            # token offsets (assuming offset 0 marks a padded sentence slot),
            # and zero it out for irrelevant passages.
            sp_loss = (sp_loss * batch["sent_offsets"].ne(0).float()) * batch["label"]
            sp_loss = sp_loss.sum()
        start_positions, end_positions = batch["starts"], batch["ends"]
        rank_loss = F.binary_cross_entropy_with_logits(rank_score, rank_target.float(), reduction="sum")
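
        # Span loss marginalizes over all annotated answer occurrences:
        # per-candidate cross-entropy losses are negative log-probabilities;
        # padded candidates (loss exactly 0 via ignore_index) are pushed to
        # -inf so exp() removes them from the per-example marginal.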
        start_losses = [self.loss_fct(start_logits, starts) for starts in torch.unbind(start_positions, dim=1)]
        end_losses = [self.loss_fct(end_logits, ends) for ends in torch.unbind(end_positions, dim=1)]
        loss_tensor = (torch.cat([t.unsqueeze(1) for t in start_losses], dim=1)
                       + torch.cat([t.unsqueeze(1) for t in end_losses], dim=1))
        log_prob = -loss_tensor
        log_prob = log_prob.float().masked_fill(log_prob == 0, float('-inf')).type_as(log_prob)
        marginal_probs = torch.sum(torch.exp(log_prob), dim=1)
        m_prob = [marginal_probs[idx] for idx in marginal_probs.nonzero()]
        if len(m_prob) == 0:
            # No answerable example in the batch: an all-ignored target
            # (assuming self.loss_fct uses ignore_index=-1) yields a zero
            # span loss while keeping the loss a tensor in the graph.
            span_loss = self.loss_fct(start_logits, start_logits.new_zeros(start_logits.size(0)).long() - 1).sum()
        else:
            span_loss = -torch.log(torch.cat(m_prob)).sum()
        if self.sp_pred:
            loss = rank_loss + span_loss + sp_loss * self.sp_weight
        else:
            loss = rank_loss + span_loss
        return loss.unsqueeze(0)
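
    # Inference: hand back raw scores for answer extraction and reranking.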
    return {
        'start_logits': start_logits,
        'end_logits': end_logits,
        'rank_score': rank_score,
        'sp_score': sp_score,
    }
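
# A minimal driver sketch, not from the source: it assumes a wrapping QAModel
# class exposing this forward(), that self.loss_fct is a CrossEntropyLoss with
# reduction="none" and ignore_index=-1, and invented tensor shapes. All names
# and shapes below are illustrative assumptions, not the repo's actual API.
if __name__ == "__main__":
    import torch

    B, L, S, A = 2, 128, 5, 3  # batch size, seq length, sentences, answer spans
    batch = {
        "input_ids": torch.randint(0, 30000, (B, L)),
        "attention_mask": torch.ones(B, L, dtype=torch.long),
        "paragraph_mask": torch.ones(B, L, dtype=torch.long),  # 1 = passage token
        "sent_offsets": torch.randint(1, L, (B, S)),  # sentence-marker positions
        "sent_labels": torch.randint(0, 2, (B, S)),   # supporting-fact labels
        "label": torch.ones(B, 1),                    # passage relevance labels
        "starts": torch.randint(0, L, (B, A)),        # answer start candidates
        "ends": torch.randint(0, L, (B, A)),          # answer end candidates
    }
    # model = QAModel(config, model_name)   # hypothetical constructor
    # model.train(); loss = model(batch)    # -> shape (1,) combined loss
    # model.eval();  scores = model(batch)  # -> dict of logits/scores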