in lama/utils.py [0:0]
def print_sentence_predictions(log_probs, token_ids, vocab,
                               masked_indices=None, print_generation=True,
                               topk=1000):
    """Score a tokenized sentence against model log-probabilities.

    Computes the average negative log-likelihood over non-special tokens
    and reports it as perplexity, optionally printing a per-token
    generation table via ``__print_generation``.

    Parameters
    ----------
    log_probs : torch.Tensor
        Per-position log-probabilities over the vocabulary; only the
        first ``len(token_ids)`` rows are used.
    token_ids : sequence of int
        Gold token ids for the sentence.
    vocab : sequence
        Vocabulary used to map ids to surface forms and to detect
        special symbols.
    masked_indices : list or None
        Positions of [MASK] tokens; only the first mask is scored.
        ``None`` is treated as "no masks".
    print_generation : bool
        If True, print the generation table and the perplexity message.
    topk : int
        Number of top predictions to rank against.

    Returns
    -------
    (float, str)
        The perplexity and the formatted message.
    """
    msg = "\n"
    log_probs = log_probs[:len(token_ids)]
    value_max_probs, index_max_probs = __get_topk(log_probs, topk)
    # remove special symbols from token_ids
    excluded_indices = __exclude_tokens(list(token_ids), vocab)
    # score only first mask; guard against the default None (previously
    # crashed with TypeError on None[:1])
    masked_indices = (masked_indices or [])[:1]
    tokens = torch.from_numpy(np.asarray(token_ids))
    # get ranking position of each gold token within the topk predictions
    query = tokens.squeeze().data.unsqueeze(-1)
    query = query.repeat(1, topk)
    ranking_position = (index_max_probs == query.numpy()).nonzero()
    rank_dict = dict(zip(*ranking_position))
    # positional (log-prob) score of the correct token at each position
    token_probs = log_probs.gather(
        dim=1,
        index=tokens.view(-1, 1),
    )
    positional_scores = token_probs.squeeze(-1).detach().numpy()
    # average NLL over non-special tokens only
    score_sum = 0.
    count = 0
    for idx, score in enumerate(positional_scores):
        if idx not in excluded_indices:
            score_sum += score
            count += 1
    if count > 0:
        avg_nll_loss = - (score_sum / count)
    else:
        # all tokens excluded: define loss as 0 so perplexity is 1
        avg_nll_loss = 0.0
    perplexity = np.exp(avg_nll_loss)
    __print_generation(positional_scores, token_ids, vocab, rank_dict,
                       index_max_probs, value_max_probs, topk,
                       excluded_indices, masked_indices, print_generation)
    msg += '| Perplexity: {:.3f}\n'.format(perplexity)
    if print_generation:
        print("\n"+msg+"\n")
    return perplexity, msg