in hype_kg/codes/model.py [0:0]
def forward(self, sample, rel_len, qtype, mode='single'):
    """Embed a batch of queries and score them with ``self.model_name``.

    Columns of the query index tensor read by this method (``sample`` in
    'single' mode, ``sample[0]`` in 'tail-batch' mode):

    * '1-chain' / '2-chain' / '3-chain': column 0 is the anchor entity and
      columns 1..rel_len are the relations; relations are stacked on dim 1.
    * '2-inter' / '3-inter' / '2-union' / '3-union': anchors in columns 0, 2
      (and 4 when rel_len == 3), relations in columns 1, 3 (and 5); each
      branch is stacked on dim 0 and the tail is repeated once per branch.
    * 'chain-inter': anchors in columns 0 and 3, relations in 1, 2 and 4.
    * 'inter-chain' / 'union-chain': anchors in columns 0 and 2, relations
      in 1, 3 and 4.

    Args:
        sample: in 'single' mode, an index tensor with one query per row
            whose last column is the answer entity; in 'tail-batch' mode, a
            ``(head_part, tail_part)`` pair where ``tail_part`` holds
            ``(batch_size, negative_sample_size)`` candidate answer ids.
        rel_len: query length. For chains it must equal the number of
            relations read; for set operations it selects 2 or 3 branches;
            for the other query types it is passed through to the model.
        qtype: one of the query types listed above.
        mode: 'single' or 'tail-batch'.

    Returns:
        ``(score, score_cen, offset_norm, score_cen_plus, None, None)``;
        the first four come from ``self.BoxTransE`` / ``self.TransE``.

    Raises:
        ValueError: unsupported ``qtype``, ``mode`` or ``self.model_name``.
        AssertionError: a 'chain-inter' / 'inter-chain' / 'union-chain'
            query outside 'tail-batch' mode, or a chain whose relation
            count does not match ``rel_len``.
    """
    chain_types = ('1-chain', '2-chain', '3-chain')
    set_op_types = ('2-inter', '3-inter', '2-union', '3-union')

    # Resolve which index columns hold anchor entities and relations. This
    # runs before `sample` is touched so an unsupported qtype fails with a
    # clear ValueError instead of an UnboundLocalError further down.
    if qtype == 'chain-inter':
        assert mode == 'tail-batch'
        anchor_cols, relation_cols = (0, 3), (1, 2, 4)
    elif qtype in ('inter-chain', 'union-chain'):
        assert mode == 'tail-batch'
        anchor_cols, relation_cols = (0, 2), (1, 3, 4)
    elif qtype in set_op_types:
        if rel_len == 3:
            anchor_cols, relation_cols = (0, 2, 4), (1, 3, 5)
        else:
            anchor_cols, relation_cols = (0, 2), (1, 3)
    elif qtype in chain_types:
        anchor_cols = (0,)
        # Only 2- and 3-hop chains read extra relation columns; any other
        # rel_len reads just column 1 and trips the size assertion below.
        relation_cols = (1, 2, 3)[:rel_len] if rel_len in (2, 3) else (1,)
    else:
        raise ValueError('qtype %s not supported' % qtype)

    # Resolve the query indices and the candidate-answer embeddings.
    if mode == 'single':
        query_idx = sample
        tail = torch.index_select(self.entity_embedding, dim=0, index=sample[:, -1]).unsqueeze(1)
    elif mode == 'tail-batch':
        query_idx, tail_part = sample
        batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
        tail = torch.index_select(self.entity_embedding, dim=0, index=tail_part.view(-1)).view(batch_size, negative_sample_size, -1)
    else:
        raise ValueError('mode %s not supported' % mode)

    def embed(table, cols, dim, relation=False):
        # Look up query_idx[:, col] in `table` for each col and concatenate
        # along `dim`; relation lookups get one extra singleton axis.
        parts = []
        for col in cols:
            rows = torch.index_select(table, dim=0, index=query_idx[:, col]).unsqueeze(1)
            parts.append(rows.unsqueeze(1) if relation else rows)
        return parts[0] if len(parts) == 1 else torch.cat(parts, dim=dim)

    is_chain = qtype in chain_types
    # Chains stack their hops along dim 1; every other query type stacks its
    # branches along dim 0.
    relation_dim = 1 if is_chain else 0
    use_box = self.geo == 'box'

    head = embed(self.entity_embedding, anchor_cols, dim=0)
    relation = embed(self.relation_embedding, relation_cols, dim=relation_dim, relation=True)
    # Offsets are only looked up for box geometry; entity (anchor) offsets
    # additionally require self.euo. Otherwise the model receives None.
    offset = None
    head_offset = None
    if use_box:
        offset = embed(self.offset_embedding, relation_cols, dim=relation_dim, relation=True)
        if self.euo:
            head_offset = embed(self.entity_offset_embedding, anchor_cols, dim=0)

    if is_chain:
        assert relation.size(1) == rel_len
        if use_box:
            assert offset.size(1) == rel_len
    elif qtype in set_op_types and rel_len in (2, 3):
        # One copy of the candidates per branch, matching head's dim-0 stacking.
        tail = torch.cat([tail] * rel_len, dim=0)

    model_func = {
        'BoxTransE': self.BoxTransE,
        'TransE': self.TransE,
    }
    if self.model_name not in model_func:
        raise ValueError('model %s not supported' % self.model_name)
    # Set-operation queries hand the model a relation length of 1: their
    # branches are stacked on dim 0, so relation.size(1) is 1.
    model_rel_len = 1 if qtype in set_op_types else rel_len
    score, score_cen, offset_norm, score_cen_plus, _ = model_func[self.model_name](head, relation, tail, mode, offset, head_offset, model_rel_len, qtype)
    return score, score_cen, offset_norm, score_cen_plus, None, None