def SGD_predict()

in TransferQA/T5.py [0:0]


def SGD_predict(args, tokenizer, model, save_path, save_folder="sgd_prediction"):
    """Generate slot-value predictions for the SGD test data and save them.

    The gold ``slot_values`` of every USER-turn frame are cleared, the model
    generates one answer per example in the test loader, and the answers are
    written back into the dialogue structure, which is finally dumped to
    ``<save_path>/<save_folder>/output.json``.

    Args:
        args: Options forwarded unchanged to ``prepare_SGD_data``.
        tokenizer: Tokenizer supplying ``eos_token_id`` and ``batch_decode``;
            also forwarded to ``prepare_SGD_data``.
        model: Model exposing ``to``, ``eval`` and ``generate``.
        save_path: Base directory for the prediction output.
        save_folder: Sub-directory of ``save_path`` that receives
            ``output.json``; created if it does not exist.

    Returns:
        None. The predictions are written to disk.

    Raises:
        AssertionError: If a batch item's ``dialogue_id`` does not match the
            dialogue found at its ``ID`` index in the loaded data.
    """
    output_dir = os.path.join(save_path, save_folder)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    test_loader, sgd_data = prepare_SGD_data(args, tokenizer)
    multi_choices_collection = []

    # Prefer the first GPU, but fall back to CPU so that prediction still
    # runs on machines without CUDA instead of failing on model.to().
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Delete all the gold slot values so the output holds predictions only.
    for dial in sgd_data:
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                for frame in turn["frames"]:
                    frame["state"]["slot_values"] = {}

    for batch in tqdm(test_loader):
        dst_outputs = model.generate(
            input_ids=batch["encoder_input"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )

        value_batch = tokenizer.batch_decode(dst_outputs, skip_special_tokens=True)
        for idx, value in enumerate(value_batch):
            dial_id = batch["ID"][idx]
            turn_id = batch["turn_id"][idx]
            frame_id = batch["frame_id"][idx]
            slot_key = batch["slot_text"][idx]
            question_type = batch["question_type"][idx]

            # Double check that the index really points at the same dialogue.
            assert sgd_data[dial_id]["dialogue_id"] == batch["dialogue_id"][idx], (
                "dialogue_id mismatch between batch item and loaded SGD data"
            )

            if question_type == "extractive":
                # The literal answer "none" means no value is recorded.
                if value != "none":
                    frame = sgd_data[dial_id]["turns"][turn_id]["frames"][frame_id]
                    frame["state"]["slot_values"][slot_key] = [value]
            elif question_type == "multi-choice":
                # Collected now, applied after all batches have been processed.
                multi_choices_collection.append({
                    "dial_id": dial_id,
                    "turn_id": turn_id,
                    "frame_id": frame_id,
                    "slot_key": slot_key,
                    "value": [value],
                })

    # Update the extractive prediction with the multi-choice prediction.
    # A multi-choice answer only overwrites a slot that already holds an
    # extractive value; it never adds a new slot on its own.
    for example in multi_choices_collection:
        turn = sgd_data[example["dial_id"]]["turns"][example["turn_id"]]
        slot_values = turn["frames"][example["frame_id"]]["state"]["slot_values"]
        if example["slot_key"] in slot_values:
            slot_values[example["slot_key"]] = example["value"]

    with open(os.path.join(output_dir, "output.json"), 'w') as fout:
        json.dump(sgd_data, fout, indent=4)