# def print_sentence_predictions()
#
# in lama/utils.py [0:0]


def print_sentence_predictions(log_probs, token_ids, vocab,
                               masked_indices=None, print_generation=True,
                               topk=1000):
    """Score a tokenized sentence against the model's predictions.

    Computes the per-position log-probability of each gold token, the
    rank of each gold token within the top-``topk`` predictions, and the
    sentence perplexity over non-special positions.

    Args:
        log_probs: per-position log-probability distribution over the
            vocabulary (torch tensor); only the first ``len(token_ids)``
            rows are used.
        token_ids: gold token ids for the sentence.
        vocab: vocabulary used to detect special symbols to exclude.
        masked_indices: masked positions; only the first mask is kept.
            ``None`` is treated as "no masks" (the previous code raised
            TypeError when the default was used).
        print_generation: when True, print the generation details and
            the perplexity message.
        topk: number of top-ranked vocabulary entries kept per position.

    Returns:
        Tuple ``(perplexity, msg)`` where ``msg`` is the formatted
        perplexity report string.
    """
    msg = "\n"
    log_probs = log_probs[:len(token_ids)]
    value_max_probs, index_max_probs = __get_topk(log_probs, topk)

    # remove special symbols from token_ids
    excluded_indices = __exclude_tokens(list(token_ids), vocab)

    # score only first mask; coalesce the documented None default, which
    # previously crashed on masked_indices[:1]
    masked_indices = (masked_indices or [])[:1]

    tokens = torch.from_numpy(np.asarray(token_ids))

    # get ranking position in topk: compare each gold id against the
    # topk index matrix, row-wise
    query = tokens.squeeze().data.unsqueeze(-1)
    query = query.repeat(1, topk)
    ranking_position = (index_max_probs == query.numpy()).nonzero()
    rank_dict = dict(zip(*ranking_position))

    # log-probability assigned to the gold token at every position
    token_probs = log_probs.gather(
        dim=1,
        index=tokens.view(-1, 1),
    )
    positional_scores = token_probs.squeeze(-1).detach().numpy()

    # average NLL over non-special positions; 0.0 when none remain
    kept_scores = [score for idx, score in enumerate(positional_scores)
                   if idx not in excluded_indices]
    if kept_scores:
        avg_nll_loss = - (sum(kept_scores) / len(kept_scores))
    else:
        avg_nll_loss = 0.0
    perplexity = np.exp(avg_nll_loss)

    __print_generation(positional_scores, token_ids, vocab, rank_dict,
                       index_max_probs, value_max_probs, topk,
                       excluded_indices, masked_indices, print_generation)

    msg += '| Perplexity: {:.3f}\n'.format(perplexity)

    if print_generation:
        print("\n"+msg+"\n")

    return perplexity, msg