def monitor_metrics()

in mico/utils/utils.py


import torch
from sklearn import metrics


def monitor_metrics(encodings_x, encodings_y):
    """This function helps evaluate the model during training.

    Parameters
    ----------
    encodings_x : `torch.tensor`
        This is the probability of queries routed to each cluster according to MICO model.
    encodings_y : `torch.tensor`
        This is the probability of document assigned to each cluster according to MICO model.
    
    Returns
    -------
    perf : dictionary
        This performance metrics are in a dictionary:
         {'AUC': area under the curve (search cost v.s. top-1 coverage) in percentage;
          'top1_cov': top-1 coverage in percentage,
                      in how much proportion of the data (query-document pair), the top-1 cluster for query routing 
                      and the top-1 cluster for document assignment are the same cluster;
          'H__Z_X': the entropy of cluster distribution of query routing (to their top-1 cluster) in this evaluation data;
          'H__Z_Y': the entropy of cluster distribution of document assigment (to their top-1 cluster) in this evaluation data;
         }
    """
    # Top-1 cluster for each query and for each document.
    argmax_x = torch.argmax(encodings_x, dim=1)
    argmax_y = torch.argmax(encodings_y, dim=1)
    # Top-1 coverage: fraction of pairs whose query and document share the same top-1 cluster.
    num_correct = torch.sum(argmax_x == argmax_y).float()
    acc = num_correct / encodings_x.size(0)

    # Count how many queries / documents have each cluster as their top-1 choice.
    num_clusters = encodings_x.size(1)
    enc_y_list = []
    enc_x_prob = []
    enc_y_prob = []
    for i in range(num_clusters):
        count_x = torch.sum(argmax_x == i).item()
        count_y = torch.sum(argmax_y == i).item()
        enc_y_list.append(count_y)
        # Keep only non-empty clusters so that log(0) never appears in the entropies below.
        if count_x > 0:
            enc_x_prob.append(count_x)
        if count_y > 0:
            enc_y_prob.append(count_y)
    enc_y_list = torch.tensor(enc_y_list).float()
    enc_x_prob = torch.tensor(enc_x_prob).float()
    enc_y_prob = torch.tensor(enc_y_prob).float()

    # Normalize the counts into distributions and compute their entropies.
    enc_x_prob = enc_x_prob / torch.sum(enc_x_prob)
    enc_y_prob = enc_y_prob / torch.sum(enc_y_prob)
    enc_x_entropy = torch.sum(-enc_x_prob * torch.log(enc_x_prob)).item()
    enc_y_entropy = torch.sum(-enc_y_prob * torch.log(enc_y_prob)).item()

    # For each query, rank clusters by routing probability (descending) and accumulate
    # the number of documents visited as the search expands cluster by cluster.
    argsort_x = torch.argsort(-encodings_x, dim=1)
    total_cost_mat = torch.cat(
        [torch.cumsum(enc_y_list[argsort_x[i]], 0).unsqueeze(0) for i in range(encodings_x.shape[0])], dim=0)
    total_cost_mean_curve = torch.mean(total_cost_mat, dim=0).cpu()
    total_cost_mean_curve = total_cost_mean_curve / total_cost_mean_curve[-1]
    # Coverage curve: the chance that the document's top-1 cluster has been reached
    # after visiting the k highest-ranked clusters for its query.
    coverage_mat = torch.cat(
        [torch.cumsum((argsort_x[i] == argmax_y[i]).float(), dim=0).unsqueeze(0) for i in range(encodings_x.shape[0])],
        dim=0)
    coverage_curve = torch.mean(coverage_mat, dim=0).cpu()
    # Area under the (search cost, coverage) curve, with the origin prepended.
    metric_auc = metrics.auc(torch.cat([torch.tensor([0]).float(), total_cost_mean_curve], dim=0),
                             torch.cat([torch.tensor([0]).float(), coverage_curve], dim=0))

    return {'AUC': metric_auc * 100, 'top1_cov': acc.item() * 100,
            'H__Z_X': enc_x_entropy, 'H__Z_Y': enc_y_entropy}
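
A minimal usage sketch (illustration only, not part of the source file): the inputs are assumed to be row-wise probability distributions of shape (num_pairs, num_clusters), as implied by the argmax and size calls above. The random softmax encodings below are placeholders for real MICO model outputs.

import torch

# Placeholder encodings: random logits turned into per-row probability distributions.
num_pairs, num_clusters = 1000, 64
encodings_x = torch.softmax(torch.randn(num_pairs, num_clusters), dim=1)
encodings_y = torch.softmax(torch.randn(num_pairs, num_clusters), dim=1)

perf = monitor_metrics(encodings_x, encodings_y)
print(perf)  # {'AUC': ..., 'top1_cov': ..., 'H__Z_X': ..., 'H__Z_Y': ...}

With independent random encodings, 'top1_cov' should sit near 100 / num_clusters percent and the AUC near 50, which makes this a quick sanity check for the metric code itself.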