in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/utils.py [0:0]
def perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """Perturb the object of each test triplet and return its filtered rank.

    For each of the first ``test_size`` triplets ``(s[i], r[i], o[i])`` the
    object is replaced by every candidate returned by ``filter_o``. Each
    candidate ``o'`` is scored as ``sum(embedding[s] * w[r] * embedding[o'])``.
    The position of the true object in the descending score order is
    recorded (0 = highest-scoring candidate).

    Args:
        embedding: Entity embedding table indexed by entity id; its first
            dimension is taken as the number of entities.
        w: Relation parameters indexed by relation id.
        s, r, o: Indexable sequences of subject, relation and object ids.
        test_size: Number of leading triplets to evaluate.
        triplets_to_filter: Known triplets, forwarded to ``filter_o``.

    Returns:
        A ``torch.long`` tensor of shape ``(test_size,)`` holding the 0-based
        rank of each true object.
    """
    num_entities = embedding.shape[0]
    ranks = []
    # Only integer ranks are kept, so no autograd graph is needed.
    with torch.no_grad():
        for idx in range(test_size):
            if idx % 100 == 0:
                print("test triplet {} / {}".format(idx, test_size))
            target_s = s[idx]
            target_r = r[idx]
            target_o = o[idx]
            # NOTE(review): assumes filter_o returns a 1-D tensor of candidate
            # object ids containing target_o exactly once; otherwise the int()
            # conversion below raises.
            filtered_o = filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities)
            target_o_idx = int((filtered_o == target_o).nonzero())
            emb_s = embedding[target_s]
            emb_r = w[target_r]
            emb_o = embedding[filtered_o]
            # Rank on the raw scores. A sigmoid is monotonic and so cannot
            # change the ordering, but in float32 it saturates to exactly 1.0
            # for large scores, creating artificial ties whose relative order
            # after torch.sort is unspecified.
            scores = torch.sum(emb_s * emb_r * emb_o, dim=1)
            _, indices = torch.sort(scores, descending=True)
            rank = int((indices == target_o_idx).nonzero())
            ranks.append(rank)
    return torch.tensor(ranks, dtype=torch.long)