in TransferQA/T5.py [0:0]
def SGD_predict(args, tokenizer, model, save_path, save_folder="sgd_prediction"):
    """Generate slot-value predictions for SGD dialogues and dump them to JSON.

    The data returned by ``prepare_SGD_data`` first has every USER frame's
    ``state["slot_values"]`` reset to an empty dict, then the model's generated
    answers are written back into it. The result is saved to
    ``<save_path>/<save_folder>/output.json``.

    Args:
        args: Run configuration, forwarded unchanged to ``prepare_SGD_data``.
        tokenizer: Tokenizer providing ``eos_token_id`` and ``batch_decode``.
        model: Model exposing ``to``, ``eval`` and ``generate``.
        save_path: Base directory for the output.
        save_folder: Sub-directory of ``save_path`` that receives
            ``output.json``; created if it does not exist.

    Returns:
        None. The predictions are written to disk.

    Raises:
        AssertionError: If a batch item's ``dialogue_id`` does not match the
            dialogue found at its ``ID`` index in the loaded data.
    """
    output_dir = os.path.join(save_path, save_folder)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    test_loader, sgd_data = prepare_SGD_data(args, tokenizer)
    multi_choices_collection = []

    # Use the first GPU when one is present; otherwise run on CPU instead of
    # crashing on hosts without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Delete all the gold slot values so that only predictions end up in the
    # output file.
    for dial in sgd_data:
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                for frame in turn["frames"]:
                    frame["state"]["slot_values"] = {}

    for batch in tqdm(test_loader):
        dst_outputs = model.generate(
            input_ids=batch["encoder_input"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        value_batch = tokenizer.batch_decode(dst_outputs, skip_special_tokens=True)

        for idx, value in enumerate(value_batch):
            dial_id = batch["ID"][idx]
            turn_id = batch["turn_id"][idx]
            frame_id = batch["frame_id"][idx]
            slot_key = batch["slot_text"][idx]
            question_type = batch["question_type"][idx]

            # Double check that the batch item points at the dialogue we are
            # about to modify.
            assert sgd_data[dial_id]["dialogue_id"] == batch["dialogue_id"][idx], (
                "batch item does not match the dialogue at its ID index"
            )

            if question_type == "extractive" and value != "none":
                slot_values = sgd_data[dial_id]["turns"][turn_id]["frames"][frame_id]["state"]["slot_values"]
                slot_values[slot_key] = [value]

            # Collect multi-choice answers; they are applied only after every
            # batch has been processed (see below).
            if question_type == "multi-choice":
                multi_choices_collection.append(
                    {
                        "dial_id": dial_id,
                        "turn_id": turn_id,
                        "frame_id": frame_id,
                        "slot_key": slot_key,
                        "value": [value],
                    }
                )

    # Update the extractive prediction with the multi-choice prediction. A
    # multi-choice answer is used only for slots that already received a
    # non-"none" extractive value; it never introduces a new slot.
    for example in multi_choices_collection:
        turn = sgd_data[example["dial_id"]]["turns"][example["turn_id"]]
        slot_values = turn["frames"][example["frame_id"]]["state"]["slot_values"]
        if example["slot_key"] in slot_values:
            slot_values[example["slot_key"]] = example["value"]

    with open(os.path.join(output_dir, "output.json"), 'w') as fout:
        json.dump(sgd_data, fout, indent=4)