def main()

in model/mm_dst/gpt2_dst/scripts/format_retrieval_results.py [0:0]
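
This helper reshapes a flat file of per-candidate retrieval scores (one float per line, NUM_OPTIONS lines per turn, in dialog order) into per-dialog, per-turn score lists keyed by dialog id. A minimal sketch of the shapes involved, using hypothetical field values; the keys are the ones the code below actually reads and writes:

# Shape of the input dialogs JSON (only the keys this script reads):
dialogs = {
    "dialogue_data": [
        {
            "dialogue_idx": 0,
            "dialogue": [{}, {}],  # two turns; turn contents are not used here
        },
    ],
}

# Shape of the result written to formatted_output_file:
formatted_result = [
    {
        "dialog_id": 0,
        "candidate_scores": [
            {"turn_id": 0, "scores": []},  # NUM_OPTIONS negated NLLs per turn
            {"turn_id": 1, "scores": []},
        ],
    },
]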


import json

# Number of retrieval candidates scored per turn; assumed to be 100,
# matching the SIMMC response-retrieval setting.
NUM_OPTIONS = 100


def main(args):
    print(f"""Reading dialogs: {args["dialog_json_file"]}""")
    with open(args["dialog_json_file"], "r") as file_id:
        dialogs = json.load(file_id)

    print(f"""Reading outputs: {args["model_output_file"]}""")
    with open(args["model_output_file"], "r") as file_id:
        scores = [float(ii) for ii in file_id.readlines()]

    # The model emits NUM_OPTIONS scores per turn, so the score count
    # must equal NUM_OPTIONS times the total number of turns.
    num_turns = sum(len(ii["dialogue"]) for ii in dialogs["dialogue_data"])
    assert len(scores) == NUM_OPTIONS * num_turns, "#scores do not match #turns!"
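
    # Worked example of the flat layout (hypothetical numbers): with
    # NUM_OPTIONS = 3 and num_turns = 2, the score file has 6 lines;
    # turn 0 owns scores[0:3] and turn 1 owns scores[3:6].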

    formatted_result = []
    turn_counter = 0  # running index into the flat score list
    for dialog_datum in dialogs["dialogue_data"]:
        dialog_id = dialog_datum["dialogue_idx"]
        new_entry = {"dialog_id": dialog_id, "candidate_scores": []}
        for turn_id, turn_datum in enumerate(dialog_datum["dialogue"]):
            # Each turn owns a contiguous block of NUM_OPTIONS scores.
            start_ind = turn_counter * NUM_OPTIONS
            end_ind = (turn_counter + 1) * NUM_OPTIONS

            # Scores are NLLs, so lower is better; negate them so that
            # a higher score means a better candidate.
            new_turn_entry = {
                "turn_id": turn_id,
                "scores": [-1 * ii for ii in scores[start_ind:end_ind]],
            }
            turn_counter += 1
            new_entry["candidate_scores"].append(new_turn_entry)
        formatted_result.append(new_entry)

    # Write the result back.
    print(f"""Saving: {args["formatted_output_file"]}""")
    with open(args["formatted_output_file"], "w") as file_id:
        json.dump(formatted_result, file_id)