def compute_loss()

in dpr/models/reader.py

Computes the total training loss for DPR's extractive reader: a passage-selection ("switch") cross-entropy over each question's M passages, plus a span start/end loss marginalized over the annotated answer spans.


import torch
from torch.nn import CrossEntropyLoss


def compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M):
    # N questions per batch, M passages per question; logits arrive flattened
    # as (N * M, seq_len), positions/masks carry K answer-span slots per passage
    start_positions = start_positions.view(N * M, -1)
    end_positions = end_positions.view(N * M, -1)
    answer_mask = answer_mask.view(N * M, -1)

    start_logits = start_logits.view(N * M, -1)
    end_logits = end_logits.view(N * M, -1)
    relevance_logits = relevance_logits.view(N * M)

    # cast the span mask to float on the GPU so it can scale per-span losses
    answer_mask = answer_mask.float().cuda()

    # positions outside the sequence are clamped to ignored_index, which
    # CrossEntropyLoss then skips via ignore_index
    ignored_index = start_logits.size(1)
    start_positions.clamp_(0, ignored_index)
    end_positions.clamp_(0, ignored_index)
    # reduction="none" keeps a per-example loss (reduce=False is deprecated)
    loss_fct = CrossEntropyLoss(reduction="none", ignore_index=ignored_index)

    # compute switch loss: the positive passage sits at index 0 among each
    # question's M passages, hence the all-zero relevance targets
    relevance_logits = relevance_logits.view(N, M)
    switch_labels = torch.zeros(N, dtype=torch.long).cuda()
    switch_loss = torch.sum(loss_fct(relevance_logits, switch_labels))

    # compute span loss: unbind over the answer-candidate dimension so each
    # candidate span contributes an (N * M,) cross-entropy term, zeroed out
    # by the mask where no annotation exists
    start_losses = [
        (loss_fct(start_logits, _start_positions) * _span_mask)
        for (_start_positions, _span_mask) in zip(
            torch.unbind(start_positions, dim=1), torch.unbind(answer_mask, dim=1)
        )
    ]

    end_losses = [
        (loss_fct(end_logits, _end_positions) * _span_mask)
        for (_end_positions, _span_mask) in zip(
            torch.unbind(end_positions, dim=1), torch.unbind(answer_mask, dim=1)
        )
    ]

    # per-candidate loss = start term + end term, shape (N * M, K)
    loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat(
        [t.unsqueeze(1) for t in end_losses], dim=1
    )

    # keep, per question, the best-scoring passage for each candidate span,
    # then marginalize over spans (see the _calc_mml sketch below)
    loss_tensor = loss_tensor.view(N, M, -1).max(dim=1)[0]
    span_loss = _calc_mml(loss_tensor)
    return span_loss + switch_loss
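
The helper _calc_mml is defined elsewhere in dpr/models/reader.py. For reference, here is a sketch of an ORQA-style marginal maximum likelihood step consistent with the shapes used above; treat the exact masking constants as assumptions rather than the repository's verbatim code:

def _calc_mml(loss_tensor):
    # loss_tensor: (N, K) per-candidate span losses; zeros mark masked slots.
    # A large offset pushes exp(-loss) of masked entries to ~0 so they drop
    # out of the marginal likelihood.
    marginal_likelihood = torch.sum(torch.exp(-loss_tensor - 1e10 * (loss_tensor == 0).float()), 1)
    # guard against log(0) when a question has no valid candidate span
    return -torch.sum(
        torch.log(marginal_likelihood + torch.ones(loss_tensor.size(0)).cuda() * (marginal_likelihood == 0).float())
    )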
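
A minimal smoke test under assumed shapes (N questions, M passages per question, L tokens per passage, K answer-span slots; all values here are hypothetical). A CUDA device is required, since compute_loss moves tensors to the GPU internally:

N, M, L, K = 2, 3, 16, 4
start_logits = torch.randn(N * M, L).cuda()
end_logits = torch.randn(N * M, L).cuda()
relevance_logits = torch.randn(N * M).cuda()
start_positions = torch.randint(0, L, (N, M, K)).cuda()
end_positions = torch.randint(0, L, (N, M, K)).cuda()
answer_mask = torch.ones(N, M, K, dtype=torch.long).cuda()

loss = compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M)
print(loss.item())  # scalar: span_loss + switch_loss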