in hype_kg/codes/model.py [0:0]
def BoxTransE(self, head, relation, tail, mode, offset, head_offset, rel_len, qtype):
    """Score ``tail`` against the query box built by translating ``head``.

    Each query branch starts at ``head``. When ``self.euo`` is truthy, the head
    is itself a box of half-width ``0.5 * self.func(head_offset)``. The branch
    is then translated by relation vectors and widened by
    ``0.5 * self.func(offset slice)`` on each side for every relation.
    Branches are combined according to ``qtype``:

    * ``'chain-inter'``, ``'inter-chain'``, ``'union-chain'``: composite
      two-branch queries. Three relations are packed along dim 0 of
      ``relation``/``offset`` (each read at index 0 of dim 1), and two heads
      (and head offsets) are packed along dim 0.
    * ``'<k>-inter'`` / ``'<k>-union'``: ``k`` single-path branches packed
      along dim 0, each translated by ``rel_len`` relations.
    * any other qtype: a single path of ``rel_len`` relations.

    Intersections go through ``self.center_sets`` / ``self.offset_sets``.
    Unions keep every branch and take the max score over branches (dim 0).

    Args:
        head: query anchor embeddings (box centers before translation).
        relation: relation translations, read as ``relation[:, i, :, :]``.
        tail: candidate embeddings scored against the query box.
        mode: not used by this function; kept for interface compatibility.
        offset: relation box widths, indexed like ``relation``.
        head_offset: head box widths; only read when ``self.euo`` is truthy.
        rel_len: number of relations per path for non-composite qtypes.
        qtype: query-structure name (see above).

    Returns:
        ``(score, score_center, offset_reg, score_center_plus, None)`` where:

        * ``score = gamma - ||dist_outside||_1``
        * ``score_center = gamma2 - ||center - tail||_1``
        * ``score_center_plus = gamma - ||dist_outside||_1 - cen * ||dist_inside||_1``
        * ``offset_reg`` is ``mean(||offset||_2 over dim 2)``

    Raises:
        ValueError: for a ``'<k>-inter'`` / ``'<k>-union'`` qtype with ``k < 2``,
            or a ``'<k>-inter'`` qtype with ``k`` other than 2 or 3 (only 2- and
            3-way intersections are supported).
    """
    func = self.func

    def slot0(t):
        # Composite qtypes store each packed relation at index 0 of dim 1.
        return t[:, 0, :, :]

    def build_box(center, rels, widths):
        # Translate `center` by each relation in turn, then grow the box by
        # 0.5 * func(w) per side for every width, one term at a time.
        # Applying the terms sequentially keeps the same left-to-right
        # floating-point order as the expanded `c - a - b - ...` expressions.
        for rel in rels:
            center = center + rel
        query_min = query_max = center
        for width in widths:
            half = 0.5 * func(width)
            query_min = query_min - half
            query_max = query_max + half
        return center, query_min, query_max

    def intersect(centers, mins, maxs):
        # Pass interleaved (center, width) pairs, with dim 1 squeezed, to the
        # learned set operators; returns (center, offset) of the result.
        args = []
        for center, query_min, query_max in zip(centers, mins, maxs):
            args.extend((center.squeeze(1), (query_max - query_min).squeeze(1)))
        return self.center_sets(*args), self.offset_sets(*args)

    def distances(query_min, query_max, query_center, target):
        # Per-dimension (outside-box distance, center offset, inside-box offset).
        outside = F.relu(query_min - target) + F.relu(target - query_max)
        center_dist = query_center - target
        inside = torch.min(query_max, torch.max(query_min, target)) - query_center
        return outside, center_dist, inside

    if qtype in ('chain-inter', 'inter-chain', 'union-chain'):
        rels = [slot0(r) for r in torch.chunk(relation, 3, dim=0)]
        rel_widths = [slot0(o) for o in torch.chunk(offset, 3, dim=0)]
        heads = torch.chunk(head, 2, dim=0)
        if self.euo:
            head_widths = [[w] for w in torch.chunk(head_offset, 2, dim=0)]
        else:
            head_widths = [[], []]

        if qtype == 'chain-inter':
            # Branch 1 follows relations 0 then 1, branch 2 follows relation 2;
            # the two branch boxes are then intersected.
            c1, min1, max1 = build_box(heads[0], [rels[0], rels[1]],
                                       head_widths[0] + [rel_widths[0], rel_widths[1]])
            c2, min2, max2 = build_box(heads[1], [rels[2]],
                                       head_widths[1] + [rel_widths[2]])
            new_center, new_offset = intersect((c1, c2), (min1, min2), (max1, max2))
            _, new_min, new_max = build_box(new_center, [], [new_offset])
            outside, center_dist, inside = distances(
                new_min.unsqueeze(1), new_max.unsqueeze(1), new_center.unsqueeze(1), tail)
        elif qtype == 'inter-chain':
            # Intersect branch 1 (relation 0) with branch 2 (relation 1), then
            # translate the intersection by relation 2.
            c1, min1, max1 = build_box(heads[0], [rels[0]], head_widths[0] + [rel_widths[0]])
            c2, min2, max2 = build_box(heads[1], [rels[1]], head_widths[1] + [rel_widths[1]])
            conj_center, conj_offset = intersect((c1, c2), (min1, min2), (max1, max2))
            new_center, new_min, new_max = build_box(
                conj_center.unsqueeze(1), [rels[2]], [conj_offset.unsqueeze(1), rel_widths[2]])
            outside, center_dist, inside = distances(new_min, new_max, new_center, tail)
        else:  # 'union-chain'
            # Each branch follows its own relation (0 or 1), then the shared
            # relation 2; the branches are kept separate (union).
            c1, min1, max1 = build_box(heads[0], [rels[0], rels[2]],
                                       head_widths[0] + [rel_widths[0], rel_widths[2]])
            c2, min2, max2 = build_box(heads[1], [rels[1], rels[2]],
                                       head_widths[1] + [rel_widths[1], rel_widths[2]])
            outside, center_dist, inside = distances(
                torch.stack([min1, min2], dim=0),
                torch.stack([max1, max2], dim=0),
                torch.stack([c1, c2], dim=0),
                tail)
    else:
        path_rels = [relation[:, i, :, :] for i in range(rel_len)]
        path_widths = [offset[:, i, :, :] for i in range(rel_len)]
        if self.euo:
            path_widths.insert(0, head_offset)
        query_center, query_min, query_max = build_box(head, path_rels, path_widths)

        if 'inter' not in qtype and 'union' not in qtype:
            outside, center_dist, inside = distances(query_min, query_max, query_center, tail)
        else:
            # '<k>-inter' / '<k>-union': k branches are packed along dim 0.
            n_branches = int(qtype.split('-')[0])
            if n_branches < 2:
                raise ValueError('qtype %s needs at least 2 branches' % qtype)
            centers = torch.chunk(query_center, n_branches, dim=0)
            mins = torch.chunk(query_min, n_branches, dim=0)
            maxs = torch.chunk(query_max, n_branches, dim=0)
            # Only the first tail chunk is scored (presumably the chunks repeat
            # the same tails per branch -- TODO confirm with the caller).
            tail0 = torch.chunk(tail, n_branches, dim=0)[0]
            if 'inter' in qtype:
                if n_branches not in (2, 3):
                    raise ValueError(
                        'qtype not supported: %s (only 2- and 3-way intersections)' % qtype)
                new_center, new_offset = intersect(
                    [centers[i] for i in range(n_branches)],
                    [mins[i] for i in range(n_branches)],
                    [maxs[i] for i in range(n_branches)])
                _, new_min, new_max = build_box(new_center, [], [new_offset])
                outside, center_dist, inside = distances(
                    new_min.unsqueeze(1), new_max.unsqueeze(1), new_center.unsqueeze(1), tail0)
            else:  # union
                outside, center_dist, inside = distances(
                    torch.stack(mins, dim=0),
                    torch.stack(maxs, dim=0),
                    torch.stack(centers, dim=0),
                    tail0)

    outside_norm = torch.norm(outside, p=1, dim=-1)
    score = self.gamma.item() - outside_norm
    score_center = self.gamma2.item() - torch.norm(center_dist, p=1, dim=-1)
    score_center_plus = (self.gamma.item() - outside_norm
                         - self.cen * torch.norm(inside, p=1, dim=-1))
    if 'union' in qtype:
        # A union answer only needs to fit its best-matching branch.
        score = torch.max(score, dim=0)[0]
        score_center = torch.max(score_center, dim=0)[0]
        score_center_plus = torch.max(score_center_plus, dim=0)[0]
    offset_reg = torch.mean(torch.norm(offset, p=2, dim=2).squeeze(1))
    return score, score_center, offset_reg, score_center_plus, None