in model/mm_dst/gpt2_dst/scripts/format_retrieval_results.py [0:0]
import json

# Retrieval candidates scored per turn; defined at module scope in the
# original file (100 in the SIMMC response-retrieval setup).
NUM_OPTIONS = 100


def main(args):
    """Format flat per-candidate retrieval scores into per-dialog JSON.

    Reads NUM_OPTIONS scores per turn (one float per line, NLL) and writes
    [{"dialog_id": ..., "candidate_scores": [{"turn_id": ..., "scores": [...]}]}].
    """
    print(f"""Reading dialogs: {args["dialog_json_file"]}""")
    with open(args["dialog_json_file"], "r") as file_id:
        dialogs = json.load(file_id)
    print(f"""Reading outputs: {args["model_output_file"]}""")
    with open(args["model_output_file"], "r") as file_id:
        scores = [float(ii) for ii in file_id.readlines()]
    # Number of scores should match number of instances.
    num_turns = sum(len(ii["dialogue"]) for ii in dialogs["dialogue_data"])
    assert len(scores) == NUM_OPTIONS * num_turns, "#turns do not match!"

    formatted_result = []
    # Reuse num_turns as a running turn counter across all dialogs.
    num_turns = 0
    for dialog_datum in dialogs["dialogue_data"]:
        dialog_id = dialog_datum["dialogue_idx"]
        new_entry = {"dialog_id": dialog_id, "candidate_scores": []}
        for turn_id, turn_datum in enumerate(dialog_datum["dialogue"]):
            # Each turn owns a contiguous block of NUM_OPTIONS scores.
            start_ind = num_turns * NUM_OPTIONS
            end_ind = (num_turns + 1) * NUM_OPTIONS
            # Scores are NLL, lower is better, hence -1.
            new_turn_entry = {
                "turn_id": turn_id,
                "scores": [-1 * ii for ii in scores[start_ind:end_ind]],
            }
            num_turns += 1
            new_entry["candidate_scores"].append(new_turn_entry)
        formatted_result.append(new_entry)
    # Write the result back.
    print(f"""Saving: {args["formatted_output_file"]}""")
    with open(args["formatted_output_file"], "w") as file_id:
        json.dump(formatted_result, file_id)