in python/dglke/models/ke_model.py [0:0]
def _topk_exclude_pos(self, score, idx, head, rel, tail, topk, exec_mode, exclude_mode):
    """ Generate the topk most relevant triplets and their corresponding scores.

    It takes the following steps:

      1) find the top elements of ``score``
      2) sort them in descending order
      3) call ``self._exclude_pos`` to handle existing edges according to
         ``exclude_mode``

    Parameters
    ----------
    score : th.Tensor
        Candidate scores, ranked along dim 0 (presumably 1-D with at least
        ``idx.shape[0]`` entries -- TODO confirm against callers).
    idx : th.Tensor
        Candidate indices. Only ``idx.shape[0]`` is read here; the tensor
        itself is forwarded to ``_exclude_pos``.
    head, rel, tail :
        Forwarded unchanged to ``_exclude_pos``.
    topk : int
        Number of results requested.
    exec_mode : str
        Forwarded unchanged to ``_exclude_pos``.
    exclude_mode : str
        When ``'exclude'``, up to ``topk * 4`` candidates are ranked so enough
        remain after ``_exclude_pos`` filters; if the result still looks too
        short, the whole ``score`` tensor is re-ranked and ``_exclude_pos`` is
        called again. Any other value ranks only ``min(topk, idx.shape[0])``
        candidates and passes that clamped value as ``topk``.

    Returns
    -------
    Whatever ``self._exclude_pos`` returns for the final call.
    """
    num_cand = idx.shape[0]

    def _call_exclude(sidx, ranked_score, k):
        # Every call site forwards the same arguments; only the candidate
        # order, the matching scores and the requested k differ.
        return self._exclude_pos(sidx=sidx,
                                 score=ranked_score,
                                 idx=idx,
                                 head=head,
                                 rel=rel,
                                 tail=tail,
                                 topk=k,
                                 exec_mode=exec_mode,
                                 exclude_mode=exclude_mode)

    def _rank_top(k):
        # Pick the k best scores, then map their descending order back to
        # positions in the full ``score`` tensor. th.topk already sorts by
        # default; the argsort is kept so tie ordering matches the previous
        # implementation. The scores are returned as produced by th.topk.
        top_score, top_pos = th.topk(score, k=k, dim=0)
        order = th.argsort(top_score, dim=0, descending=True)
        return top_pos[order], top_score

    if exclude_mode != 'exclude':
        k = min(topk, num_cand)
        sidx, top_score = _rank_top(k)
        return _call_exclude(sidx, top_score, k)

    # Over-fetch so enough candidates survive the exclusion step.
    over_fetch = topk * 4  # TODO(xiangsx): Find a better value of topk * n
    if num_cand < over_fetch:
        # Fewer candidates than the over-fetch budget: rank all of them.
        sidx, top_score = _rank_top(num_cand)
        return _call_exclude(sidx, top_score, topk)

    sidx, top_score = _rank_top(over_fetch)
    result = _call_exclude(sidx, top_score, topk)
    # NOTE(review): len(result) counts the elements of whatever _exclude_pos
    # returns. If that is a fixed-size tuple (e.g. head/rel/tail/score/mask)
    # rather than one entry per surviving triplet, this check does not measure
    # how many triplets survived -- confirm against _exclude_pos.
    if len(result) < topk:
        # Fall back to ranking every candidate.
        sidx = th.argsort(score, dim=0, descending=True)
        result = _call_exclude(sidx, score[sidx], topk)
    return result