in TransferQA/T5.py [0:0]
# required imports (assumed to sit at the top of T5.py; listed here so the
# snippet is self-contained):
import os
import json
import difflib

import torch
from tqdm import tqdm

# normalize_ontology and evaluate_metrics are project helpers defined
# elsewhere in this repo.


def evaluate_model(args, tokenizer, model, test_loader, save_path, ALL_SLOTS, prefix="zeroshot"):
    prefix += ("use_value_" + str(args["use_value"]))
    save_path = os.path.join(save_path, "results")
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    predictions = {}
    multi_choices_collection = []
    # move the model to GPU and switch to eval mode
    device = torch.device("cuda:0")
    model.to(device)
    model.eval()
if args["canonicalization"]:
if args["version"]=="2.0":
ontology = normalize_ontology(json.load(open("data/mwz2.0/ontology.json", 'r')))
else:
ontology = normalize_ontology(json.load(open("data/mwz2.1/ontology.json", 'r')))
# with open("data/ontology.json") as f:
# ontology = json.load(f)
    # per-slot counters: [total, hit, accuracy]
    slot_logger = {slot_name: [0, 0, 0] for slot_name in ALL_SLOTS}
    slot_logger["slot_gate"] = [0, 0, 0]
    for batch in tqdm(test_loader):
        # generate an answer string for every question in the batch
        dst_outputs = model.generate(
            input_ids=batch["encoder_input"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        value_batch = tokenizer.batch_decode(dst_outputs, skip_special_tokens=True)
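        # value_batch is a list of decoded strings, one per example,
        # e.g. ["none", "centre", "expensive"] (illustrative values).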
        for idx, value in enumerate(value_batch):
            dial_id = batch["ID"][idx]
            if dial_id not in predictions:
                predictions[dial_id] = {}
                predictions[dial_id]["domain"] = batch["domains"][idx][0]
                predictions[dial_id]["turns"] = {}
            if batch["turn_id"][idx] not in predictions[dial_id]["turns"]:
                predictions[dial_id]["turns"][batch["turn_id"][idx]] = {"turn_belief": batch["turn_belief"][idx], "pred_belief": []}
            # record active (non-"none") extractive slots in pred_belief
            if batch["question_type"][idx] == "extractive" and value != "none":
                if args["canonicalization"]:
                    # snap the generated value onto the closest ontology entry
                    matches = difflib.get_close_matches(value, ontology[batch["slot_text"][idx]], n=1)
                    if len(matches) > 0:
                        predictions[dial_id]["turns"][batch["turn_id"][idx]]["pred_belief"].append(str(batch["slot_text"][idx]) + '-' + str(matches[0]))
                        value = matches[0]
                    else:
                        # no close ontology match: treat the slot as inactive
                        value = "none"
                else:
                    predictions[dial_id]["turns"][batch["turn_id"][idx]]["pred_belief"].append(str(batch["slot_text"][idx]) + '-' + value)
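            # difflib.get_close_matches returns a (possibly empty) list of
            # near matches, e.g. get_close_matches("cenre", ["centre", "north"], n=1)
            # -> ["centre"] (illustrative).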
            # slot-gate accuracy: does the model agree with the gold label
            # on whether the slot is active ("none" vs. not "none")?
            if batch["question_type"][idx] == "extractive":
                if value == "none" and batch["value_text"][idx] == "none":
                    slot_logger["slot_gate"][1] += 1  # hit
                if value != "none" and batch["value_text"][idx] != "none":
                    slot_logger["slot_gate"][1] += 1  # hit
                slot_logger["slot_gate"][0] += 1  # total
            # collect multi-choice answers; they override the extractive
            # predictions for the same slots after the loop
            if batch["question_type"][idx] == "multi-choice":
                if args["canonicalization"]:
                    matches = difflib.get_close_matches(value, ontology[batch["slot_text"][idx]], n=1)
                    value = matches[0] if len(matches) > 0 else "none"
                multi_choices_collection.append({"dial_id": batch["ID"][idx], "turn_id": batch["turn_id"][idx], "slot_text": batch["slot_text"][idx], "value": value})
            # per-slot accuracy, counted only on turns where the gold value
            # is not "none"
            if batch["value_text"][idx] != "none":
                if str(value) == str(batch["value_text"][idx]):
                    slot_logger[str(batch["slot_text"][idx])][1] += 1  # hit
                slot_logger[str(batch["slot_text"][idx])][0] += 1  # total
    # replace extractive predictions with their multi-choice counterparts
    for example in multi_choices_collection:
        dial_id = example["dial_id"]
        turn_id = example["turn_id"]
        extractive_value = ""
        # find the matching active slot, if any (substring match on "slot-value")
        for kv in predictions[dial_id]["turns"][turn_id]["pred_belief"]:
            if example["slot_text"] in kv:
                extractive_value = kv
        # skip slots that were never predicted as active
        if extractive_value == "":
            continue
        # swap the extractive entry for the multi-choice one
        predictions[dial_id]["turns"][turn_id]["pred_belief"].remove(extractive_value)
        predictions[dial_id]["turns"][turn_id]["pred_belief"].append(str(example["slot_text"]) + '-' + str(example["value"]))
    # accuracy = hit / total; guard against slots that never appeared
    for slot_log in slot_logger.values():
        slot_log[2] = slot_log[1] / slot_log[0] if slot_log[0] > 0 else 0.0
    with open(os.path.join(save_path, f"{prefix}_slot_acc.json"), 'w') as f:
        json.dump(slot_logger, f, indent=4)
    with open(os.path.join(save_path, f"{prefix}_prediction.json"), 'w') as f:
        json.dump(predictions, f, indent=4)

    joint_acc_score, F1_score, turn_acc_score = evaluate_metrics(predictions, ALL_SLOTS)
    evaluation_metrics = {"Joint Acc": joint_acc_score, "Turn Acc": turn_acc_score, "Joint F1": F1_score}
    print(f"{prefix} result:", evaluation_metrics)
    with open(os.path.join(save_path, f"{prefix}_result.json"), 'w') as f:
        json.dump(evaluation_metrics, f, indent=4)
    return predictions
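
# Usage sketch (not from the original file): a minimal, hypothetical call,
# assuming `args`, `tokenizer`, `model`, `test_loader`, and `all_slots`
# were built by the surrounding training/eval pipeline (the key
# "saving_dir" is an assumption):
#
#     predictions = evaluate_model(
#         args, tokenizer, model, test_loader,
#         save_path=args["saving_dir"], ALL_SLOTS=all_slots,
#         prefix="zeroshot",
#     )
#
# Results land in <save_path>/results/ as {prefix}_slot_acc.json,
# {prefix}_prediction.json, and {prefix}_result.json.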