def main()

in model/disambiguate/format_disambiguation_data.py [0:0]


def _format_dialogs(dialogs):
    """Flatten SIMMC dialogs into per-turn disambiguation instances.

    Args:
        dialogs: Parsed SIMMC JSON. It must contain a ``"dialogue_data"`` list.
            Each entry has a ``"dialogue_idx"`` and a ``"dialogue"`` list of turns.
            Each turn has a ``"transcript"`` and, optionally,
            ``"system_transcript"`` and ``"disambiguation_label"``.

    Returns:
        A list of dicts, one per turn that carries a ``"disambiguation_label"``.
        Each dict has these keys:

        * ``dialog_id``: the dialog's ``dialogue_idx``.
        * ``turn_id``: the 0-based position of the turn in the dialog.
        * ``input_text``: the list of utterances so far, alternating user and
          system, ending with this turn's user transcript.
        * ``disambiguation_label_gt``: the turn's ``disambiguation_label``.
    """
    disambiguate_data = []
    for dialog_datum in dialogs["dialogue_data"]:
        history = []
        for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
            history.append(turn_datum["transcript"])

            if "disambiguation_label" in turn_datum:
                disambiguate_data.append(
                    {
                        "dialog_id": dialog_datum["dialogue_idx"],
                        "turn_id": turn_ind,
                        # Snapshot the history: it keeps growing on later turns.
                        "input_text": list(history),
                        "disambiguation_label_gt": turn_datum["disambiguation_label"],
                    }
                )
            # Skip a missing or empty system_transcript (e.g. the last round of teststd).
            if turn_datum.get("system_transcript", None):
                history.append(turn_datum["system_transcript"])
    return disambiguate_data


def main(args, splits=None):
    """Convert SIMMC dialog JSON files into disambiguation training data.

    For each split, this reads ``args[f"simmc_{split}_json"]`` and extracts one
    instance per turn that has a ``disambiguation_label``. It then writes the
    instances as a JSON list to
    ``<args["disambiguate_save_path"]>/simmc2_disambiguate_dstc10_<split>.json``.

    Args:
        args: Mapping with a ``simmc_<split>_json`` input path for every split
            processed, plus ``disambiguate_save_path`` (the output directory).
            The output directory is created if it does not exist.
        splits: Optional iterable of split names to process. Defaults to the
            module-level ``SPLITS``.
    """
    if splits is None:
        splits = SPLITS
    for split in splits:
        read_path = args[f"simmc_{split}_json"]
        print(f"Reading: {read_path}")
        # JSON is UTF-8 by spec; do not rely on the platform's locale default.
        with open(read_path, "r", encoding="utf-8") as file_id:
            dialogs = json.load(file_id)

        disambiguate_data = _format_dialogs(dialogs)
        print(f"# instances [{split}]: {len(disambiguate_data)}")

        save_dir = args["disambiguate_save_path"]
        # Create the directory up front so a missing one does not crash the run.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"simmc2_disambiguate_dstc10_{split}.json")
        print(f"Saving: {save_path}")
        with open(save_path, "w", encoding="utf-8") as file_id:
            json.dump(disambiguate_data, file_id)