in python/dglke/models/ke_model.py [0:0]
def _exclude_pos(self, sidx, score, idx, head, rel, tail, topk, exec_mode, exclude_mode):
    g = self.graph
    num_triples = idx.shape[0]
    num_head = 1 if exec_mode == 'batch_head' else head.shape[0]
    num_rel = 1 if exec_mode == 'batch_rel' else rel.shape[0]
    num_tail = 1 if exec_mode == 'batch_tail' else tail.shape[0]
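
    # For 'all' and 'batch_*' modes, each entry of idx is a flat index into the candidate
    # (head, rel, tail) space and is decoded below with floor_divide and modulo; for
    # 'triplet_wise', idx indexes the aligned head/rel/tail arrays directly.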
    res_head = []
    res_rel = []
    res_tail = []
    res_score = []
    result = []
    if exclude_mode == 'exclude':
        # exclude existing edges
        cur_k = 0
        batch_size = topk
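
        # Scan the ranked candidates batch by batch and keep only triples that do not
        # already exist as edges; batch_size is re-computed after each pass based on how
        # many of the topk results are still missing.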
        while (cur_k < num_triples):
            cur_sidx = sidx[cur_k:cur_k + batch_size if cur_k + batch_size < num_triples else num_triples]
            cur_score = score[cur_k:cur_k + batch_size if cur_k + batch_size < num_triples else num_triples]
            cur_idx = idx[cur_sidx]
            if exec_mode == 'triplet_wise':
                cur_head = head[cur_idx]
                cur_rel = rel[cur_idx]
                cur_tail = tail[cur_idx]
            elif exec_mode == 'all':
                tail_idx = cur_idx % num_tail
                cur_idx = floor_divide(cur_idx, num_tail)
                rel_idx = cur_idx % num_rel
                cur_idx = floor_divide(cur_idx, num_rel)
                head_idx = cur_idx % num_head
                cur_head = head[head_idx]
                cur_rel = rel[rel_idx]
                cur_tail = tail[tail_idx]
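                # Decode example (hypothetical sizes): with num_head=2, num_rel=3 and
                # num_tail=4, flat index 11 decodes to tail_idx=3, rel_idx=2, head_idx=0,
                # since 11 == (0 * 3 + 2) * 4 + 3.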
            elif exec_mode == 'batch_head':
                tail_idx = cur_idx % num_tail
                cur_idx = floor_divide(cur_idx, num_tail)
                rel_idx = cur_idx % num_rel
                cur_head = th.full((cur_sidx.shape[0],), head, dtype=head.dtype)
                cur_rel = rel[rel_idx]
                cur_tail = tail[tail_idx]
            elif exec_mode == 'batch_rel':
                tail_idx = cur_idx % num_tail
                cur_idx = floor_divide(cur_idx, num_tail)
                head_idx = cur_idx % num_head
                cur_head = head[head_idx]
                cur_rel = th.full((cur_sidx.shape[0],), rel, dtype=rel.dtype)
                cur_tail = tail[tail_idx]
            elif exec_mode == 'batch_tail':
                rel_idx = cur_idx % num_rel
                cur_idx = floor_divide(cur_idx, num_rel)
                head_idx = cur_idx % num_head
                cur_head = head[head_idx]
                cur_rel = rel[rel_idx]
                cur_tail = th.full((cur_sidx.shape[0],), tail, dtype=tail.dtype)
            # Find existing edges.
            # It is expected that there are far fewer existing edges than candidate triples.
            # The idea is: 1) we get existing edges using g.edge_ids
            #              2) sort edges according to source node id (O(nlog(n)), n is number of edges)
            #              3) sort candidate triples according to cur_head (O(mlog(m)), m is number of cur_head nodes)
            #              4) go over all candidate triples and compare them with existing edges;
            #                 as both edges and candidate triples are sorted, filtering the
            #                 existing edges out takes only O(n+m)
            #              5) sort the scores again, which takes O(klog(k))
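            # For example (hypothetical ids): if the graph already contains the edges
            # (0, r=1, 2) and (0, r=0, 2), the candidate triple (0, r=1, 2) is dropped
            # while (0, r=2, 2) is kept.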
            uid, vid, eid = g.edge_ids(cur_head, cur_tail, return_uv=True)
            rid = g.edata[self._etid_field][eid]
            for i in range(cur_head.shape[0]):
                h = cur_head[i]
                r = cur_rel[i]
                t = cur_tail[i]
                h_where = uid == h
                t_where = vid[h_where] == t
                r_where = rid[h_where][t_where]
                edge_exist = False
                if r_where.shape[0] > 0:
                    for c_r in r_where:
                        if c_r == r:
                            edge_exist = True
                            break

                if edge_exist is False:
                    res_head.append(h)
                    res_rel.append(r)
                    res_tail.append(t)
                    res_score.append(cur_score[i])
                    if len(res_head) >= topk:
                        break

            cur_k += batch_size
            batch_size = topk - len(res_head) # check more candidates in the next batch
            batch_size = 16 if batch_size < 16 else batch_size # enforce a minimum batch size to avoid tiny trailing batches
        res_head = th.tensor(res_head)
        res_rel = th.tensor(res_rel)
        res_tail = th.tensor(res_tail)
        res_score = th.tensor(res_score)
        sidx = th.argsort(res_score, dim=0, descending=True)
        sidx = sidx[:topk] if topk < sidx.shape[0] else sidx
        result.append((res_head[sidx],
                       res_rel[sidx],
                       res_tail[sidx],
                       res_score[sidx],
                       None))
    else:
        # including the existing edges in the result
        topk = topk if topk < num_triples else num_triples
        sidx = sidx[:topk]
        idx = idx[sidx]
        if exec_mode == 'triplet_wise':
            head = head[idx]
            rel = rel[idx]
            tail = tail[idx]
        elif exec_mode == 'all':
            tail_idx = idx % num_tail
            idx = floor_divide(idx, num_tail)
            rel_idx = idx % num_rel
            idx = floor_divide(idx, num_rel)
            head_idx = idx % num_head
            head = head[head_idx]
            rel = rel[rel_idx]
            tail = tail[tail_idx]
        elif exec_mode == 'batch_head':
            tail_idx = idx % num_tail
            idx = floor_divide(idx, num_tail)
            rel_idx = idx % num_rel
            head = th.full((topk,), head, dtype=head.dtype)
            rel = rel[rel_idx]
            tail = tail[tail_idx]
        elif exec_mode == 'batch_rel':
            tail_idx = idx % num_tail
            idx = floor_divide(idx, num_tail)
            head_idx = idx % num_head
            head = head[head_idx]
            rel = th.full((topk,), rel, dtype=rel.dtype)
            tail = tail[tail_idx]
        elif exec_mode == 'batch_tail':
            rel_idx = idx % num_rel
            idx = floor_divide(idx, num_rel)
            head_idx = idx % num_head
            head = head[head_idx]
            rel = rel[rel_idx]
            tail = th.full((topk,), tail, dtype=tail.dtype)
        if exclude_mode == 'mask':
            # Find existing edges.
            # It is expected that there are far fewer existing edges than candidate triples.
            # The idea is: 1) we get existing edges using g.edge_ids
            #              2) sort edges according to source node id (O(nlog(n)), n is number of edges)
            #              3) sort candidate triples according to cur_head (O(mlog(m)), m is number of cur_head nodes)
            #              4) go over all candidate triples, compare them with existing edges and mask them;
            #                 as both edges and candidate triples are sorted, masking the existing
            #                 edges takes only O(n+m)
            uid, vid, eid = g.edge_ids(head, tail, return_uv=True)
            rid = g.edata[self._etid_field][eid]
            mask = th.full((head.shape[0],), False, dtype=th.bool)
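            # mask[i] == True marks the i-th returned triple as an already existing edge,
            # so callers can decide how to treat known edges.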
            if len(uid) > 0:
                for i in range(head.shape[0]):
                    h = head[i]
                    r = rel[i]
                    t = tail[i]
                    h_where = uid == h
                    t_where = vid[h_where] == t
                    r_where = rid[h_where][t_where]
                    if r_where.shape[0] > 0:
                        for c_r in r_where:
                            if c_r == r:
                                mask[i] = True
                                break

            result.append((head, rel, tail, score, mask))
        else:
            result.append((head, rel, tail, score, None))

    return result
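
A minimal, self-contained sketch (not part of ke_model.py) of the edge-existence test the loops above perform, shown on toy tensors; the helper name _edge_exists and all node/relation ids below are hypothetical.

import torch as th

def _edge_exists(uid, vid, rid, h, r, t):
    # Same membership test as above: among the existing edges (uid[i], rid[i], vid[i]),
    # is there one whose source is h, destination is t, and relation id is r?
    h_where = uid == h                 # edges whose source node is h
    t_where = vid[h_where] == t        # ... and whose destination node is t
    r_where = rid[h_where][t_where]    # relation ids of those edges
    return bool((r_where == r).any())

uid = th.tensor([0, 0, 1])   # source nodes of existing edges
vid = th.tensor([2, 2, 3])   # destination nodes of existing edges
rid = th.tensor([0, 1, 0])   # relation ids of existing edges
assert _edge_exists(uid, vid, rid, h=0, r=1, t=2)        # (0, r=1, 2) already in the graph
assert not _edge_exists(uid, vid, rid, h=0, r=2, t=2)    # (0, r=2, 2) is new and would be kept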