in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/utils.py [0:0]
def perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """Rank the true subject of each test triplet among filtered candidates.

    For every index in ``range(test_size)`` the candidate subjects are
    obtained from ``filter_s`` (defined elsewhere in this module; presumably
    it returns a tensor of entity ids that contains the true subject exactly
    once -- confirm against its definition). Each candidate is scored as
    ``sigmoid(sum(emb_s * emb_r * emb_o, dim=1))`` and the 0-based position
    of the true subject in the descending score order is recorded.

    Progress is printed to stdout every 100 triplets.

    Args:
        embedding: Entity embeddings, indexed by entity id; its first
            dimension is used as the number of entities.
        w: Relation embeddings, indexed by relation id.
        s: Subject ids of the test triplets, indexed by position.
        r: Relation ids of the test triplets, indexed by position.
        o: Object ids of the test triplets, indexed by position.
        test_size: Number of leading triplets to evaluate.
        triplets_to_filter: Passed through unchanged to ``filter_s``.

    Returns:
        A ``torch.LongTensor`` holding one 0-based rank per evaluated triplet.
    """
    progress_template = "test triplet {} / {}"
    entity_count = embedding.shape[0]
    collected_ranks = []

    for idx in range(test_size):
        if not idx % 100:
            print(progress_template.format(idx, test_size))

        true_s, rel, obj = s[idx], r[idx], o[idx]

        candidates = filter_s(triplets_to_filter, true_s, rel, obj, entity_count)
        # int() only succeeds when exactly one candidate equals the true subject.
        true_pos = int((candidates == true_s).nonzero())

        # Element-wise product of subject, relation and object embeddings,
        # reduced over dim 1 to a single score per candidate.
        product = embedding[candidates] * w[rel] * embedding[obj]
        scores = torch.sigmoid(product.sum(dim=1))

        # Position of the true subject once candidates are ordered best-first.
        order = scores.sort(descending=True).indices
        collected_ranks.append(int((order == true_pos).nonzero()))

    return torch.LongTensor(collected_ranks)