in dpr/models/reader.py [0:0]
import torch
from torch.nn import CrossEntropyLoss


def compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M):
    # N = questions per batch, M = passages per question; each passage may
    # carry several annotated answer spans (the trailing dimension).
    # Flatten the (question, passage) dimensions so each row is one passage.
    start_positions = start_positions.view(N * M, -1)
    end_positions = end_positions.view(N * M, -1)
    answer_mask = answer_mask.view(N * M, -1)
    start_logits = start_logits.view(N * M, -1)
    end_logits = end_logits.view(N * M, -1)
    relevance_logits = relevance_logits.view(N * M)
    answer_mask = answer_mask.float().cuda()

    # Clamp out-of-range span positions to one past the last token index;
    # the loss then skips those targets via ignore_index. Note that
    # reduction="none" replaces the deprecated reduce=False argument.
    ignored_index = start_logits.size(1)
    start_positions.clamp_(0, ignored_index)
    end_positions.clamp_(0, ignored_index)
    loss_fct = CrossEntropyLoss(reduction="none", ignore_index=ignored_index)
    # Switch (passage selection) loss: the positive passage occupies index 0
    # within each question's M passages, so the relevance target is always 0.
    relevance_logits = relevance_logits.view(N, M)
    switch_labels = torch.zeros(N, dtype=torch.long).cuda()
    switch_loss = torch.sum(loss_fct(relevance_logits, switch_labels))
    # Span loss: per-span start/end cross-entropy, with padded span slots
    # zeroed out by answer_mask.
    start_losses = [
        loss_fct(start_logits, _start_positions) * _span_mask
        for (_start_positions, _span_mask) in zip(
            torch.unbind(start_positions, dim=1), torch.unbind(answer_mask, dim=1)
        )
    ]
    end_losses = [
        loss_fct(end_logits, _end_positions) * _span_mask
        for (_end_positions, _span_mask) in zip(
            torch.unbind(end_positions, dim=1), torch.unbind(answer_mask, dim=1)
        )
    ]
    loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat(
        [t.unsqueeze(1) for t in end_losses], dim=1
    )

    # Spans on non-positive passages are zeroed by the mask or skipped via
    # ignore_index, so taking the max over the passage dimension keeps the
    # positive passage's losses for each span slot.
    loss_tensor = loss_tensor.view(N, M, -1).max(dim=1)[0]
    span_loss = _calc_mml(loss_tensor)
    return span_loss + switch_loss
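
# _calc_mml is referenced above but not defined in this excerpt. Below is a
# minimal sketch of a marginal-maximum-likelihood helper, written to be
# consistent with how loss_tensor is built (zero entries mark masked span
# slots); the repo's actual implementation may differ in detail.
def _calc_mml(loss_tensor):
    # exp(-loss) recovers each span's likelihood; masked slots (loss == 0)
    # are suppressed by subtracting a large constant before exponentiating.
    marginal_likelihood = torch.sum(
        torch.exp(-loss_tensor - 1e10 * (loss_tensor == 0).float()), dim=1
    )
    # For rows with no valid span (likelihood == 0), add 1 so log() yields 0
    # instead of -inf.
    return -torch.sum(
        torch.log(
            marginal_likelihood
            + torch.ones(loss_tensor.size(0), device=loss_tensor.device)
            * (marginal_likelihood == 0).float()
        )
    )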
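
# Hypothetical smoke test: every shape and value below is illustrative, not
# taken from the repo. It needs a CUDA device, since compute_loss moves its
# masks and labels to the GPU internally.
if __name__ == "__main__":
    N, M, L, S = 2, 3, 50, 4  # questions, passages/question, tokens, span slots
    loss = compute_loss(
        start_positions=torch.randint(0, L, (N, M, S)).cuda(),
        end_positions=torch.randint(0, L, (N, M, S)).cuda(),
        answer_mask=torch.ones(N, M, S).cuda(),
        start_logits=torch.randn(N, M, L).cuda(),
        end_logits=torch.randn(N, M, L).cuda(),
        relevance_logits=torch.randn(N, M).cuda(),
        N=N,
        M=M,
    )
    print(loss.item())  # scalar: span MML loss + passage-selection loss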