def evaluate_model()

in TransferQA/T5.py [0:0]


def evaluate_model(args, tokenizer, model, test_loader, save_path, ALL_SLOTS, prefix="zeroshot"):
    """Generate slot values for every example in ``test_loader`` and score them.

    For each batch the model generates an answer per example. Non-"none"
    answers to "extractive" questions are recorded as ``"<slot>-<value>"``
    strings in the turn's predicted belief; answers to "multi-choice"
    questions are collected and, once all batches are processed, replace the
    extractive entry of the same slot in the same turn (only when such an
    entry exists).

    Three JSON files are written to ``<save_path>/results``:
    ``{prefix}_slot_acc.json`` (per-slot ``[total, hit, accuracy]`` plus a
    ``"slot_gate"`` entry), ``{prefix}_prediction.json`` and
    ``{prefix}_result.json`` (joint accuracy, turn accuracy, joint F1 as
    returned by ``evaluate_metrics``).

    Args:
        args: Mapping read for the keys ``"use_value"``, ``"canonicalization"``
            and (when canonicalization is on) ``"version"``.
        tokenizer: Object providing ``batch_decode`` and ``eos_token_id``.
        model: Object providing ``to``, ``eval`` and ``generate``.
        test_loader: Iterable of batches indexed by ``"encoder_input"``,
            ``"attention_mask"``, ``"ID"``, ``"domains"``, ``"turn_id"``,
            ``"turn_belief"``, ``"question_type"``, ``"slot_text"`` and
            ``"value_text"``.
        save_path: Directory under which the ``results`` folder is created.
        ALL_SLOTS: Iterable of slot names tracked in the per-slot accuracy log.
        prefix: Prefix for output file names; ``"use_value_<flag>"`` is
            appended to it.

    Returns:
        The nested predictions dict, keyed by dialogue id and then turn id.
    """
    prefix += ("use_value_" + str(args["use_value"]))
    save_path = os.path.join(save_path, "results")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)

    predictions = {}
    multi_choices_collection = []

    # Prefer the first GPU, but stay runnable on CPU-only machines.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    ontology = None
    if args["canonicalization"]:
        if args["version"] == "2.0":
            ontology_path = "data/mwz2.0/ontology.json"
        else:
            ontology_path = "data/mwz2.1/ontology.json"
        # Context manager so the ontology file handle is closed after loading.
        with open(ontology_path, 'r') as f:
            ontology = normalize_ontology(json.load(f))

    # Each entry is [total, hit, accuracy].
    slot_logger = {slot_name: [0, 0, 0] for slot_name in ALL_SLOTS}
    slot_logger["slot_gate"] = [0, 0, 0]

    for batch in tqdm(test_loader):
        dst_outputs = model.generate(input_ids=batch["encoder_input"].to(device),
                                attention_mask=batch["attention_mask"].to(device),
                                eos_token_id=tokenizer.eos_token_id,
                                max_length=200,
                                )

        value_batch = tokenizer.batch_decode(dst_outputs, skip_special_tokens=True)

        for idx, value in enumerate(value_batch):
            dial_id = batch["ID"][idx]
            turn_id = batch["turn_id"][idx]
            slot_text = batch["slot_text"][idx]
            question_type = batch["question_type"][idx]
            gold_value = batch["value_text"][idx]

            if dial_id not in predictions:
                predictions[dial_id] = {
                    "domain": batch["domains"][idx][0],
                    "turns": {},
                }
            turns = predictions[dial_id]["turns"]
            if turn_id not in turns:
                turns[turn_id] = {"turn_belief": batch["turn_belief"][idx], "pred_belief": []}
            pred_belief = turns[turn_id]["pred_belief"]

            if question_type == "extractive":
                # Record active (non-"none") slots in the predicted belief.
                if value != "none":
                    if args["canonicalization"]:
                        # Snap the generated text to the closest ontology value;
                        # with no close match the slot is treated as inactive.
                        matches = difflib.get_close_matches(value, ontology[slot_text], n=1)
                        if matches:
                            value = matches[0]
                            pred_belief.append(str(slot_text) + '-' + str(value))
                        else:
                            value = "none"
                    else:
                        pred_belief.append(str(slot_text) + '-' + value)

                # Slot-gate accuracy: a hit when prediction and gold agree on
                # whether the slot is "none".
                if (value == "none") == (gold_value == "none"):
                    slot_logger["slot_gate"][1] += 1  # hit
                slot_logger["slot_gate"][0] += 1  # total

            # Collect multi-choice answers; they are merged in after the loop.
            if question_type == "multi-choice":
                if args["canonicalization"]:
                    matches = difflib.get_close_matches(value, ontology[slot_text], n=1)
                    value = matches[0] if matches else "none"
                multi_choices_collection.append({"dial_id": dial_id, "turn_id": turn_id, "slot_text": slot_text, "value": value})

            # Per-slot accuracy, counted only where the gold value is not "none".
            if gold_value != "none":
                slot_log = slot_logger[str(slot_text)]
                if str(value) == str(gold_value):
                    slot_log[1] += 1  # hit
                slot_log[0] += 1  # total

    for example in multi_choices_collection:
        pred_belief = predictions[example["dial_id"]]["turns"][example["turn_id"]]["pred_belief"]
        # Entries are stored as "<slot>-<value>", so match on that exact prefix
        # rather than a substring that could also occur inside a value.
        slot_prefix = str(example["slot_text"]) + '-'
        extractive_value = ""
        for kv in pred_belief:
            if kv.startswith(slot_prefix):
                extractive_value = kv
        # Slot is not active in this turn: nothing to replace.
        if extractive_value == "":
            continue
        # Replace the extractive entry with the multi-choice answer.
        pred_belief.remove(extractive_value)
        pred_belief.append(slot_prefix + str(example["value"]))

    for slot_log in slot_logger.values():
        # Slots that never occurred have total == 0; report 0.0 rather than
        # raising ZeroDivisionError before any result file is written.
        slot_log[2] = slot_log[1] / slot_log[0] if slot_log[0] else 0.0

    with open(os.path.join(save_path, f"{prefix}_slot_acc.json"), 'w') as f:
        json.dump(slot_logger, f, indent=4)

    with open(os.path.join(save_path, f"{prefix}_prediction.json"), 'w') as f:
        json.dump(predictions, f, indent=4)

    joint_acc_score, F1_score, turn_acc_score = evaluate_metrics(predictions, ALL_SLOTS)

    evaluation_metrics = {"Joint Acc": joint_acc_score, "Turn Acc": turn_acc_score, "Joint F1": F1_score}
    print(f"{prefix} result:", evaluation_metrics)

    with open(os.path.join(save_path, f"{prefix}_result.json"), 'w') as f:
        json.dump(evaluation_metrics, f, indent=4)

    return predictions