in model/disambiguate/format_disambiguation_data.py [0:0]
def _build_disambiguation_instances(dialogs):
    """Flatten SIMMC dialogs into one record per disambiguation-labelled turn.

    Args:
        dialogs: Parsed SIMMC JSON; must contain a ``"dialogue_data"`` list whose
            entries have a ``"dialogue_idx"`` and a ``"dialogue"`` list of turns.

    Returns:
        List of dicts with keys ``dialog_id``, ``turn_id``, ``input_text`` (all
        utterances up to and including the current user transcript) and
        ``disambiguation_label_gt``, in that key order.
    """
    instances = []
    for dialog_datum in dialogs["dialogue_data"]:
        history = []
        for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
            history.append(turn_datum["transcript"])
            if "disambiguation_label" in turn_datum:
                instances.append(
                    {
                        "dialog_id": dialog_datum["dialogue_idx"],
                        "turn_id": turn_ind,
                        # Snapshot the history: later turns keep appending to it.
                        "input_text": list(history),
                        "disambiguation_label_gt": turn_datum["disambiguation_label"],
                    }
                )
            # Ignore if system_transcript is missing or empty (last round teststd).
            if turn_datum.get("system_transcript", None):
                history.append(turn_datum["system_transcript"])
    return instances


def main(args, splits=None):
    """Convert each SIMMC split into the disambiguation-classifier format.

    For every split, reads ``args[f"simmc_{split}_json"]`` and writes
    ``simmc2_disambiguate_dstc10_{split}.json`` into
    ``args["disambiguate_save_path"]``.

    Args:
        args: Mapping holding the per-split input JSON paths and the output
            directory.
        splits: Optional iterable of split names to process; defaults to the
            module-level ``SPLITS``.
    """
    if splits is None:
        splits = SPLITS
    for split in splits:
        read_path = args[f"simmc_{split}_json"]
        print(f"Reading: {read_path}")
        # Explicit UTF-8: the platform default encoding is not always UTF-8.
        with open(read_path, "r", encoding="utf-8") as file_id:
            dialogs = json.load(file_id)

        # Reformat into (dialog history, disambiguation label) records.
        disambiguate_data = _build_disambiguation_instances(dialogs)
        print(f"# instances [{split}]: {len(disambiguate_data)}")

        save_path = os.path.join(
            args["disambiguate_save_path"], f"simmc2_disambiguate_dstc10_{split}.json"
        )
        print(f"Saving: {save_path}")
        with open(save_path, "w", encoding="utf-8") as file_id:
            json.dump(disambiguate_data, file_id)