def test_eval()

in scripts/adapet/ADAPET/src/eval/eval_model.py [0:0]


def test_eval(config, model, batcher):
    '''
    Runs the model over the test partition, writes the predictions to
    ``<config.exp_dir>/test.json`` and returns them.

    No accuracy is computed here: predictions are only handed to the
    ``Writer`` and accumulated for the caller. The total inference time
    (including batch loading) is printed once the loop has finished.

    :param config: experiment configuration; ``exp_dir`` and ``dataset`` are read
    :param model: model providing ``eval()`` and ``predict(batch)``, where
                  ``predict`` returns ``(pred_lbl, lbl_logits)`` tensors
    :param batcher: provides ``get_dataset_reader()`` and ``get_test_batch()``
    :return: tuple ``(pred_labels, pred_logits)`` - the predicted labels of all
             batches as one flat list, and the matching logits as a numpy array
    '''

    model.eval()
    dataset_reader = batcher.get_dataset_reader()
    test_writer = Writer(os.path.join(config.exp_dir, "test.json"), dataset_reader)

    # Loop-invariant: 'fewglue/record' batches keep their ids and labels under
    # different keys than every other dataset.
    is_record = config.dataset.lower() == 'fewglue/record'

    pred_labels = []
    pred_logits = []

    # perf_counter() is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()

    with torch.no_grad():
        for batch in batcher.get_test_batch():
            pred_lbl, lbl_logits = model.predict(batch)

            # Move the logits to host memory once; the same array feeds both
            # the in-memory accumulator and the writer.
            logits_np = lbl_logits.cpu().numpy()

            pred_labels.extend(pred_lbl.cpu().numpy().tolist())
            pred_logits.extend(logits_np.tolist())

            if is_record:
                list_idx = batch["input"]["qas_idx"]
                list_lbl = batch["input"]["candidate_entity"]
            else:
                list_idx = batch["input"]["idx"]
                if not isinstance(list_idx, list):
                    # Not a plain list: treated as a tensor and converted.
                    list_idx = list_idx.cpu().numpy().tolist()
                batch_output = batch["output"]
                # Prefer "true_lbl" when the batch carries it, else fall back to "lbl".
                list_lbl = batch_output["true_lbl"] if "true_lbl" in batch_output else batch_output["lbl"]

            test_writer.add_batch(list_idx, pred_lbl, list_lbl, logits_np)

    print('total inference time: {}'.format(time.perf_counter() - start))
    test_writer.flush_file()
    return pred_labels, np.array(pred_logits)