def predict_batch()

in recipes/sota/2019/rescoring/forward_lm.py


import numpy
import torch


def predict_batch(sentences, model, fairseq_dict, max_len):
    """Score a batch of tokenized sentences with a fairseq forward LM.

    Returns the per-sentence log-probability sums, their total over the
    batch, and the number of scored tokens (including one end-of-sentence
    token per sentence).
    """
    encoded_input = []
    padded_input = []
    ppls = []

    total_loss = 0.0
    nwords = 0
    for sentence in sentences:
        # Map tokens to their indices in the fairseq dictionary.
        encoded_input.append([fairseq_dict.index(token) for token in sentence])
        assert (
            len(encoded_input[-1]) <= max_len
        ), "Error in the input length, it should be no longer than max_len {}".format(
            max_len
        )
        # Prepend eos as the start-of-sentence symbol and pad shorter
        # sentences with eos so every row has length max_len + 1.
        if len(encoded_input[-1]) < max_len:
            padded_input.append(
                [fairseq_dict.eos()]
                + encoded_input[-1]
                + [fairseq_dict.eos()] * (max_len - len(encoded_input[-1]))
            )
        else:
            padded_input.append([fairseq_dict.eos()] + encoded_input[-1])
    x = torch.LongTensor(padded_input).cuda()
    with torch.no_grad():
        y = model.forward(x)[0]
        # If the model uses an adaptive softmax, take log probabilities from
        # its head; otherwise apply log_softmax to the output logits.
        if model.adaptive_softmax is not None:
            logprobs = (
                model.adaptive_softmax.get_log_prob(y, None).detach().cpu().numpy()
            )
        else:
            logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()

    for index, input_i in enumerate(encoded_input):
        # Sum the log probabilities of the reference tokens, then add the
        # log probability of the closing end-of-sentence token.
        loss = numpy.sum(logprobs[index, numpy.arange(len(input_i)), input_i])
        loss += logprobs[index, len(input_i), fairseq_dict.eos()]
        ppls.append(loss)

        total_loss += loss
        nwords += len(input_i) + 1
    return ppls, total_loss, nwords
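
A hedged usage sketch follows. The checkpoint path, the example sentences, and the loading code are illustrative assumptions and not part of the recipe itself; predict_batch expects an object whose forward returns a (output, extra) tuple and which exposes an adaptive_softmax attribute, which a fairseq transformer LM decoder satisfies. The sketch turns the returned totals into a corpus-level perplexity.

# Hedged usage sketch -- the checkpoint path, sentences, and loading code are
# illustrative assumptions, not part of the recipe shown above.
import math

from fairseq import checkpoint_utils

models, _, task = checkpoint_utils.load_model_ensemble_and_task(["lm.pt"])
# predict_batch reads model.adaptive_softmax, so pass the LM decoder itself.
model = models[0].decoder.cuda().eval()
fairseq_dict = task.source_dictionary

# Sentences must already be tokenized with the same vocabulary as the LM.
sentences = [["the", "cat", "sat"], ["hello", "world"]]
max_len = max(len(s) for s in sentences)

ppls, total_loss, nwords = predict_batch(sentences, model, fairseq_dict, max_len)

# total_loss is a sum of natural-log probabilities over nwords scored tokens,
# so corpus perplexity is exp of the negative per-token average.
print("per-sentence log-probability sums:", ppls)
print("corpus perplexity: {:.2f}".format(math.exp(-total_loss / nwords)))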