in lama/utils.py [0:0]
def __print_generation(positional_scores, token_ids, vocab, rank_dict,
                       index_max_probs, value_max_probs, topk,
                       indices_to_exclude, masked_indices, print_on_console):
    """Build a per-token table comparing input tokens with top predictions.

    One row is produced for every position ``idx`` in ``token_ids``. Each row
    contains:

    - the position;
    - the input token's string form;
    - its score from ``positional_scores``;
    - the top predicted token;
    - that prediction's score;
    - the rank looked up in ``rank_dict`` (``-1`` when ``idx`` is absent).

    Args:
        positional_scores: indexable by position. Each entry is formatted
            with ``.3f``, so it must support float formatting (presumably a
            log probability, per the column header; confirm with callers).
        token_ids: iterable of token ids. Each one indexes ``vocab``.
        vocab: maps a token id to its string form via ``vocab[id]``.
        rank_dict: mapping from position to an integer rank. Missing
            positions are reported as ``-1``.
        index_max_probs: per position, a sequence whose first element is the
            id of the top predicted token.
        value_max_probs: per position, a sequence whose first element is the
            score of the top predicted token.
        topk: used only in the ``rank@{topk}`` column header.
        indices_to_exclude: optional collection of positions printed with
            ``colored(..., 'grey', 'on_grey')`` on the console.
            NOTE(review): grey text on a grey background may be hard to read
            in some terminals; confirm this is intended.
        masked_indices: optional collection of positions printed on a yellow
            background on the console. Takes precedence over
            ``indices_to_exclude``.
        print_on_console: when true, the table is also printed to stdout,
            with colour highlighting.

    Returns:
        The whole table as a string: a dashed rule, the header, a dashed
        rule, then one newline-terminated row per token. Colour codes are
        only used for console printing, never in the returned text.
    """
    dash = '-' * 82
    header = '{:<8s}{:<20s}{:<12s}{:<20}{:<12s}{:<12s}'.format(
        "index", "token", "log_prob", "prediction",
        "log_prob", "rank@{}".format(topk))
    banner = dash + "\n" + header + "\n" + dash

    if print_on_console:
        # Colour setup is only needed when something is written to the
        # terminal. Calling init() unconditionally re-initialises the global
        # output stream on every call, even when nothing is printed.
        init()  # colorful output
        print(banner)

    # Accumulate pieces and join once, instead of quadratic `str +=`.
    parts = [banner, "\n"]
    for idx, tok in enumerate(token_ids):
        rank = rank_dict[idx] if idx in rank_dict else -1
        predicted_token_id = index_max_probs[idx][0]
        row = '{:<8d}{:<20s}{:<12.3f}{:<20s}{:<12.3f}{:<12d}'.format(
            idx,
            str(vocab[tok]),
            positional_scores[idx],
            str(vocab[predicted_token_id]),
            value_max_probs[idx][0],
            rank,
        )
        if print_on_console:
            if masked_indices is not None and idx in masked_indices:
                print(colored(row, 'grey', 'on_yellow'))
            elif indices_to_exclude is not None and idx in indices_to_exclude:
                print(colored(row, 'grey', 'on_grey'))
            else:
                print(row)
        parts.append(row + "\n")
    return "".join(parts)