in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/utils.py [0:0]
def calc_filtered_mrr(embedding, w, train_triplets, test_triplets, hits=()):
    """Compute and print filtered MRR (and optional Hits@k) for ``test_triplets``.

    Each test triplet is ranked twice: once by perturbing its subject and once
    by perturbing its object. The ranks from both directions are pooled before
    averaging. All known triplets (train + test) are collected into a set and
    handed to the perturbation helpers. Presumably they use it to drop
    known-true candidates from the ranking ("filtered" setting); confirm
    against ``perturb_s_and_get_filtered_rank`` /
    ``perturb_o_and_get_filtered_rank``, which are defined elsewhere in this
    module.

    Args:
        embedding: Node embeddings, passed through unchanged to the
            perturbation helpers.
        w: Relation parameters, passed through unchanged to the perturbation
            helpers.
        train_triplets: 2-D tensor of training triplets. Columns 0/1/2 are
            read as subject / relation / object.
        test_triplets: 2-D tensor of test triplets with the same column layout.
        hits: Iterable of cutoffs ``k``; Hits@k is printed for each one.
            Defaults to an empty tuple (an immutable default instead of a
            shared mutable list), meaning no Hits@k is reported.

    Returns:
        float: The filtered mean reciprocal rank over both perturbation
        directions.
    """
    with torch.no_grad():
        s = test_triplets[:, 0]
        r = test_triplets[:, 1]
        o = test_triplets[:, 2]
        test_size = test_triplets.shape[0]

        # All known triplets as hashable tuples, so membership checks
        # against this set are O(1).
        triplets_to_filter = {
            tuple(triplet)
            for triplet in torch.cat([train_triplets, test_triplets]).tolist()
        }

        print('Perturbing subject...')
        ranks_s = perturb_s_and_get_filtered_rank(
            embedding, w, s, r, o, test_size, triplets_to_filter)
        print('Perturbing object...')
        ranks_o = perturb_o_and_get_filtered_rank(
            embedding, w, s, r, o, test_size, triplets_to_filter)

        # The helpers' ranks are 0-indexed; MRR and Hits@k need 1-indexed ranks.
        ranks = torch.cat([ranks_s, ranks_o]) + 1

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (filtered): {:.6f}".format(mrr.item()))
        for hit in hits:
            # Fraction of test cases whose true entity ranked within the top `hit`.
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr.item()