def eval_single_token_prediction()

in custom/evaluate_utils.py [0:0]


def eval_single_token_prediction(model, itr, dictionary, singletoken_topp=0.0, singletoken_topk=1):
    """Evaluate single-token (next-token) predictions of ``model`` over ``itr``.

    For every sample the model is run once, the pad column of the logits is
    masked, and the greedy predictions, targets, MLE loss, ranking metrics and
    "wrong mass" (softmax mass on tokens scored above the true token) are
    collected.

    Args:
        model: Called as ``model(**sample['net_input'])``; must also provide
            ``get_normalized_probs(net_output, log_probs=True)``.
        itr: Iterable of samples, each with ``'net_input'`` and ``'target'``.
        dictionary: Provides ``pad()`` and ``pad_index``.
        singletoken_topp: Passed to ``TrainingMetrics.ranking_metrics`` as
            ``topp``.
        singletoken_topk: Passed to ``TrainingMetrics.ranking_metrics`` as
            ``topk``.

    Returns:
        A tuple ``(predicted_tokens, target_tokens, custom_metrics)`` where the
        first two are lists with one list of token ids per sample and
        ``custom_metrics`` is the result of
        ``TrainingMetrics.aggregate_and_normalize`` extended with ``'ppl'`` and
        ``'avg_wrong_mass'``. Both extra entries are NaN when no non-pad
        target token was seen.
    """
    # Function-scope import so the block does not depend on a module-level
    # `import torch` being present in the file.
    import torch

    predicted_tokens = []
    target_tokens = []

    mle_loss_sum = 0
    num_samples_sum = 0
    wrong_mass_sum = 0.0

    logging_outputs = []

    # Evaluation only: disabling autograd keeps the graph of earlier batches
    # from being retained while the running sums are accumulated.
    with torch.no_grad():
        # Wrapping `itr` directly (rather than `enumerate(itr)`) lets tqdm
        # pick up the iterator's length when it has one.
        for sample in tqdm(itr):
            sample = utils.move_to_cuda(sample)
            net_output = model(**sample['net_input'])
            # NOTE(review): indexing with [0][0] keeps only the first batch
            # element while `target` is flattened over the whole batch, so
            # this presumably expects one sequence per sample -- confirm.
            logits = net_output[0][0]
            # In-place write; `logits` is a view into `net_output[0]`, so the
            # masked pad column is also what `net_output` holds from here on.
            logits[:, dictionary.pad()] = -1e19
            predicted_tokens.append(logits.argmax(1).tolist())
            target = sample['target'].view(-1)
            target_tokens.append(target.tolist())

            # -- mle loss
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            true_token_lprobs = F.nll_loss(
                lprobs,
                target,
                ignore_index=dictionary.pad_index,
                reduction='none',
            )
            # nll_loss on raw logits just gathers -logit[target]; negating it
            # gives the logit of the true token (0 where target is pad).
            true_token_logits = -F.nll_loss(
                logits,
                target,
                ignore_index=dictionary.pad_index,
                reduction='none',
            )
            mle_loss = true_token_lprobs.sum()
            orig = utils.strip_pad(target, dictionary.pad_index)
            ntokens = orig.numel()

            mle_loss_sum += mle_loss.item()
            num_samples_sum += ntokens

            logging_output = TrainingMetrics.ranking_metrics(
                logits,
                true_token_logits,
                sample,
                ntokens,
                target,
                topk=singletoken_topk,
                topp=singletoken_topp,
            )

            # Pad targets have a placeholder true-token logit of 0, so their
            # rows are excluded; the average below is over non-pad tokens.
            non_pad = target.ne(dictionary.pad_index).float()
            negative_targets = (logits > true_token_logits[:, None]).float() * non_pad[:, None]
            wrong_mass = (negative_targets * F.softmax(logits, dim=1)).sum()
            # Accumulate as a Python float so no tensor outlives the batch.
            wrong_mass_sum += wrong_mass.item()

            logging_outputs.append(logging_output)

    if num_samples_sum > 0:
        # Equivalent to exp(mean NLL): 2 ** (mean NLL expressed in bits).
        ppl = math.pow(2, mle_loss_sum / num_samples_sum / math.log(2))
        avg_wrong_mass = wrong_mass_sum / num_samples_sum
    else:
        # Nothing to average over (empty iterator or only pad targets).
        ppl = float('nan')
        avg_wrong_mass = float('nan')

    custom_metrics = TrainingMetrics.aggregate_and_normalize(logging_outputs)
    custom_metrics['ppl'] = ppl
    custom_metrics['avg_wrong_mass'] = avg_wrong_mass
    return predicted_tokens, target_tokens, custom_metrics