in scripts/batch_eval_KB_completion.py [0:0]
def run_thread(arguments):
    """Score one sample and gather its metrics plus accumulated log text.

    ``arguments`` is a dict; this function reads the keys
    ``filtered_log_probs``, ``masked_indices``, ``vocab``, ``label_index``,
    ``index_list`` and ``interactive``. When ``interactive`` is truthy it also
    reads ``sample``, ``original_log_probs`` and ``token_ids``.
    (Presumably one such dict is built per sample by the caller and dispatched
    to a worker pool — confirm against the call site.)

    Returns:
        A 5-tuple ``(experiment_result, sample_MRR, sample_P,
        sample_perplexity, msg)``. The first three come from
        ``metrics.get_ranking``. ``sample_perplexity`` is ``0.0`` unless
        interactive mode computed it. ``msg`` is the concatenation of every
        helper's returned message, each prefixed with a newline.

    Note:
        In interactive mode this pretty-prints the sample and then blocks on
        stdin (``input``) until the user presses enter.
    """
    # Step 1: ranking metrics over the filtered log-probabilities.
    ranking = metrics.get_ranking(
        arguments["filtered_log_probs"],
        arguments["masked_indices"],
        arguments["vocab"],
        label_index=arguments["label_index"],
        index_list=arguments["index_list"],
        print_generation=arguments["interactive"],
        topk=10000,
    )
    sample_MRR, sample_P, experiment_result, ranking_msg = ranking
    log_chunks = ["\n" + ranking_msg]

    interactive = arguments["interactive"]
    if not interactive:
        # Non-interactive runs skip the perplexity pass entirely.
        return experiment_result, sample_MRR, sample_P, 0.0, "".join(log_chunks)

    # Step 2 (optional, for debugging only): show the raw sample, then compute
    # perplexity and print predictions over the complete, unfiltered log_probs.
    pprint(arguments["sample"])
    sample_perplexity, sentence_msg = print_sentence_predictions(
        arguments["original_log_probs"],
        arguments["token_ids"],
        arguments["vocab"],
        masked_indices=arguments["masked_indices"],
        print_generation=interactive,
    )
    # Pause so the printed predictions can be inspected before moving on.
    input("press enter to continue...")
    log_chunks.append("\n" + sentence_msg)

    msg = "".join(log_chunks)
    return experiment_result, sample_MRR, sample_P, sample_perplexity, msg