def evaluate()

in mico/evaluate/evaluate.py


import torch
from tqdm import tqdm

# monitor_metrics is imported from elsewhere in the mico package.


def evaluate(model, eval_loader, num_batches=-1, device=None):
    """This function evaluates the model during training.

    Parameters
    ----------
    model : MutualInfoCotrain object
        The MICO model being trained.
    eval_loader : PyTorch DataLoader object
        The dataloader used for evaluation; in practice either val_dataloader or test_dataloader.
    num_batches : int
        If set to a non-negative number n, evaluation runs only on the first
        n batches of the dataloader. The default of -1 evaluates on all batches.
    device : int (for multi-GPU) or string ('cpu' or 'cuda')
        The device to run evaluation on. Under DistributedDataParallel this is
        the GPU index, which is also the index of the sub-process.
    
    Returns
    -------
    perf : dictionary
        The performance metrics, all computed on this evaluation data:
         {'AUC': area under the curve (search cost vs. top-1 coverage), in percent;
          'top1_cov': top-1 coverage, in percent: the proportion of the data
                      (query-document pairs) for which the top-1 cluster for query routing
                      and the top-1 cluster for document assignment are the same cluster;
          'H__Z_X': the entropy of the distribution of queries over their top-1 clusters;
          'H__Z_Y': the entropy of the distribution of documents over their top-1 clusters;
         }
    """
    model.eval()
    with torch.no_grad():
        encoding_chunks_query = []
        encoding_chunks_document = []
        for batch_idx, (query, document) in enumerate(tqdm(eval_loader)):
            # With the default num_batches=-1 this never triggers, so the
            # whole dataloader is consumed.
            if batch_idx == num_batches:
                break
            # Encode each side of the batch into its cluster distribution.
            query_prob = model.forward(query=query, forward_method="encode_query", device=device)
            document_prob = model.forward(document=document, forward_method="encode_doc", device=device)
            encoding_chunks_query.append(query_prob)
            encoding_chunks_document.append(document_prob)
        # Concatenate the per-batch encodings along the batch dimension.
        encodings_query = torch.cat(encoding_chunks_query, 0).float()
        encodings_document = torch.cat(encoding_chunks_document, 0).float()
        perf = monitor_metrics(encodings_query, encodings_document)
    # Restore training mode before handing control back to the training loop.
    model.train()
    return perf
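
For context, a minimal call site might look like the sketch below. The names model, val_dataloader, and local_rank are hypothetical stand-ins for objects built by the training script, not identifiers from this file.

# Hypothetical usage: evaluate on the first 50 validation batches on this
# process's GPU (under DistributedDataParallel, device is the local rank).
perf = evaluate(model, val_dataloader, num_batches=50, device=local_rank)
print("AUC: %.2f%%, top-1 coverage: %.2f%%" % (perf['AUC'], perf['top1_cov']))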
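
The Returns section above pins down top1_cov, H__Z_X, and H__Z_Y precisely, so the following minimal sketch recomputes those three metrics from the concatenated encodings as a reading aid. It is an illustrative re-derivation under the assumption that each encoding row is a distribution over clusters, not the actual monitor_metrics implementation, and it omits the AUC term.

import torch

def monitor_metrics_sketch(encodings_query, encodings_document):
    # Top-1 cluster for query routing and for document assignment.
    z_x = encodings_query.argmax(dim=1)
    z_y = encodings_document.argmax(dim=1)
    # top1_cov: proportion of pairs whose two top-1 clusters agree, in percent.
    top1_cov = (z_x == z_y).float().mean().item() * 100.0

    def top1_entropy(z, num_clusters):
        # Shannon entropy of the empirical top-1 cluster distribution.
        p = torch.bincount(z, minlength=num_clusters).float()
        p = p / p.sum()
        p = p[p > 0]
        return -(p * p.log()).sum().item()

    num_clusters = encodings_query.shape[1]
    return {'top1_cov': top1_cov,
            'H__Z_X': top1_entropy(z_x, num_clusters),
            'H__Z_Y': top1_entropy(z_y, num_clusters)}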