in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/utils.py [0:0]
def calc_raw_mrr(embedding, w, test_triplets, hits=None, eval_bz=100):
    """Compute, print and return the raw mean reciprocal rank (MRR).

    Each test triplet is ranked twice via ``perturb_and_get_raw_rank``: once
    with the subject as the target and once with the object as the target.
    The two sets of ranks are concatenated, shifted to be 1-indexed, and
    reduced to MRR and (optionally) Hits@k, which are printed to stdout.

    Args:
        embedding: Passed through unchanged to ``perturb_and_get_raw_rank``
            (presumably the entity embedding matrix -- confirm against the
            helper).
        w: Passed through unchanged to ``perturb_and_get_raw_rank``
            (presumably the relation weights -- confirm against the helper).
        test_triplets: 2-D tensor indexed as ``[:, 0]``, ``[:, 1]`` and
            ``[:, 2]``; the columns are used as subject, relation and object.
        hits: Optional iterable of cut-offs ``k`` for which Hits@k is
            printed. ``None`` (the default) prints no Hits@k lines.
        eval_bz: Evaluation batch size forwarded to
            ``perturb_and_get_raw_rank``.

    Returns:
        The MRR as a Python float.
    """
    # A ``None`` sentinel avoids sharing one mutable default list across calls.
    if hits is None:
        hits = ()

    with torch.no_grad():
        s = test_triplets[:, 0]
        r = test_triplets[:, 1]
        o = test_triplets[:, 2]
        test_size = test_triplets.shape[0]

        # perturb subject: the subject column is the ranking target
        ranks_s = perturb_and_get_raw_rank(embedding, w, o, r, s, test_size, eval_bz)
        # perturb object: the object column is the ranking target
        ranks_o = perturb_and_get_raw_rank(embedding, w, s, r, o, test_size, eval_bz)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1  # change to 1-indexed

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (raw): {:.6f}".format(mrr.item()))

        for hit in hits:
            # Fraction of ranks that fall within the top ``hit`` positions.
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr.item()