def calc_mAP()

in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/utils.py [0:0]


def calc_mAP(embedding, w, train_triplets, test_triplets):
    """Return the mean average precision of link prediction for test source nodes.

    For every source node ``s`` in ``test_triplets[:, 0]``, every *other* node
    ``o`` (all rows of ``embedding`` except ``s``) is scored as
    ``sigmoid(sum(w * embedding[s] * embedding[o]))``.  A node ``o`` is labelled
    positive when ``(s, o)`` appears as a (source, sink) pair in either the test
    or the train triplets.  The average precision of each source's ranking is
    computed with ``average_precision_score`` and the results are averaged.

    Args:
        embedding: 2-D tensor of node embeddings; row ``i`` is node ``i``.
        w: tensor multiplied elementwise with the embedding rows when scoring
            (assumed broadcastable to an embedding row — TODO confirm shape).
        train_triplets: integer tensor; column 0 holds source ids and column 2
            sink ids (column 1 is not used here).
        test_triplets: integer tensor with the same column layout as
            ``train_triplets``.

    Returns:
        Mean of the per-row average precision scores.  There is one AP per row
        of ``test_triplets``, so a source node appearing in several test rows is
        counted several times.  ``np.mean`` of an empty list (no test rows)
        yields ``nan``.
    """
    sources = torch.cat((test_triplets[:, 0], train_triplets[:, 0]))
    sinks = torch.cat((test_triplets[:, 2], train_triplets[:, 2]))
    adj_list = convert_to_adj_list(sources.cpu().numpy(), sinks.cpu().numpy())
    num_nodes = embedding.shape[0]
    aps = []
    # `.tolist()` yields plain Python ints.  Iterating the tensor directly gives
    # 0-d tensors, which hash by object identity, so `adj_list.get(node)` could
    # never match the node-id keys built from the numpy arrays above and every
    # query would silently get zero positives.
    for node in test_triplets[:, 0].tolist():
        # Every node id except `node` itself, kept as an int64 index tensor even
        # when it is empty (a plain `torch.tensor([])` would be float).
        others = torch.cat((
            torch.arange(0, node, device=embedding.device),
            torch.arange(node + 1, num_nodes, device=embedding.device),
        ))
        with torch.no_grad():
            # Broadcasting the single source row replaces the repeated index.
            score = torch.sum(w * embedding[node] * embedding[others], dim=1)
            pred_proba = torch.sigmoid(score).cpu().numpy()
        labels = np.zeros(num_nodes)
        labels[adj_list.get(node, [])] = 1
        # Drop the query node so labels align with `others`.
        labels = np.delete(labels, node)
        aps.append(average_precision_score(labels, pred_proba))

    return np.mean(aps)