in hype_kg/codes/model.py [0:0]
def TransE(self, head, relation, tail, mode, offset, head_offset, rel_len, qtype):
    """Score candidate tails for a query using TransE-style translations.

    A query embedding is formed by translating an anchor ``head`` by one or
    more relation embeddings (``head + r1 + r2 + ...``).

    - Intersections of branches are merged with ``self.deepsets``.
    - Unions are resolved by taking the max score over branches (dim 0).
    - The final score is ``self.gamma - ||query - tail||_1``, so larger
      means closer.

    Args:
        head: anchor embeddings. Multi-branch queries stack their branches
            along dim 0, and they are split with ``torch.chunk``. Presumably
            shaped (branches * batch, 1, dim) -- TODO confirm against caller.
        relation: relation embeddings, indexed as ``relation[:, i, :, :]``
            (so rank 4). Presumably (branches * batch, rel_len, 1, dim)
            -- TODO confirm.
        tail: candidate tail embeddings. For 'N-inter' / 'N-union' queries
            only the first of the N chunks along dim 0 is used.
        mode, offset, head_offset: not used by this function. Presumably
            kept so all scoring functions share one signature -- TODO
            confirm.
        rel_len: number of relations chained per branch. Only used by the
            generic path, not by 'chain-inter', 'inter-chain' or
            'union-chain'.
        qtype: query structure.
            - 'chain-inter', 'inter-chain' and 'union-chain' are handled
              explicitly.
            - Any other qtype containing 'inter' or 'union' must start with
              its branch count, e.g. '2-union'.
            - Everything else is scored as a single relation chain.

    Returns:
        The tuple ``(score, None, None, 0., [])``.

    Raises:
        ValueError: for an intersection/union qtype whose branch-count prefix
            is not an integer or is below 2, or an intersection whose branch
            count is not 2 or 3 (the only arities passed to
            ``self.deepsets`` here).
    """
    if qtype == 'chain-inter':
        # Branch 1 chains two relations, branch 2 applies one; then intersect.
        relations = torch.chunk(relation, 3, dim=0)
        heads = torch.chunk(head, 2, dim=0)
        score_1 = (heads[0] + relations[0][:, 0, :, :] + relations[1][:, 0, :, :]).squeeze(1)
        score_2 = (heads[1] + relations[2][:, 0, :, :]).squeeze(1)
        conj_score = self.deepsets(score_1, None, score_2, None).unsqueeze(1)
        score = conj_score - tail
    elif qtype == 'inter-chain':
        # Intersect two one-hop branches, then translate by a third relation.
        relations = torch.chunk(relation, 3, dim=0)
        heads = torch.chunk(head, 2, dim=0)
        score_1 = (heads[0] + relations[0][:, 0, :, :]).squeeze(1)
        score_2 = (heads[1] + relations[1][:, 0, :, :]).squeeze(1)
        conj_score = self.deepsets(score_1, None, score_2, None).unsqueeze(1)
        score = conj_score + relations[2][:, 0, :, :] - tail
    elif qtype == 'union-chain':
        # The shared final relation is applied to each branch separately; the
        # union itself is resolved after scoring (max over dim 0 below).
        relations = torch.chunk(relation, 3, dim=0)
        heads = torch.chunk(head, 2, dim=0)
        score_1 = heads[0] + relations[0][:, 0, :, :] + relations[2][:, 0, :, :]
        score_2 = heads[1] + relations[1][:, 0, :, :] + relations[2][:, 0, :, :]
        conj_score = torch.stack([score_1, score_2], dim=0)
        score = conj_score - tail
    else:
        score = head
        for rel in range(rel_len):
            score = score + relation[:, rel, :, :]

        if 'inter' not in qtype and 'union' not in qtype:
            score = score - tail
        else:
            # The number of branches is encoded as the qtype prefix, e.g. '2-union'.
            try:
                n_branches = int(qtype.split('-')[0])
            except ValueError:
                raise ValueError('qtype not exist: %s' % qtype) from None
            if n_branches < 2:
                raise ValueError('qtype needs at least 2 branches: %s' % qtype)
            score = score.squeeze(1)
            scores = torch.chunk(score, n_branches, dim=0)
            tails = torch.chunk(tail, n_branches, dim=0)
            if 'inter' in qtype:
                if n_branches == 2:
                    conj_score = self.deepsets(scores[0], None, scores[1], None)
                elif n_branches == 3:
                    conj_score = self.deepsets(scores[0], None, scores[1], None, scores[2], None)
                else:
                    # Previously fell through with conj_score unbound.
                    raise ValueError(
                        'intersection of %d branches is not supported: %s' % (n_branches, qtype))
                conj_score = conj_score.unsqueeze(1)
                score = conj_score - tails[0]
            else:
                # 'union' is in qtype here (guaranteed by the enclosing branch).
                conj_score = torch.stack(scores, dim=0)
                score = conj_score - tails[0]

    score = self.gamma.item() - torch.norm(score, p=1, dim=-1)
    if 'union' in qtype:
        # Disjunction: a candidate scores as well as its best-matching branch.
        score = torch.max(score, dim=0)[0]
    if qtype == '2-union':
        # NOTE(review): only '2-union' gets this extra leading dim (not
        # '3-union' or 'union-chain') -- confirm callers rely on the asymmetry.
        score = score.unsqueeze(0)
    return score, None, None, 0., []