def forward()

in hype_kg/codes/model.py [0:0]


    def forward(self, sample, rel_len, qtype, mode='single'):
        """Embed a batch of queries and score them with the configured model.

        Args:
            sample: In ``'single'`` mode, an integer index tensor with one row
                per query whose last column is the tail entity. In
                ``'tail-batch'`` mode, a ``(head_part, tail_part)`` pair where
                ``head_part`` holds the query index columns and ``tail_part``
                holds candidate tail entities with shape
                ``(batch_size, negative_sample_size)``.
            rel_len: Number of relations in a chain query (1-3), or number of
                branches in an intersection/union query (2-3). Passed through
                to the model function for the composite query types.
            qtype: One of ``'1-chain'``, ``'2-chain'``, ``'3-chain'``,
                ``'2-inter'``, ``'3-inter'``, ``'2-union'``, ``'3-union'``,
                ``'chain-inter'``, ``'inter-chain'``, ``'union-chain'``.
            mode: ``'single'`` or ``'tail-batch'``. ``'chain-inter'``,
                ``'inter-chain'`` and ``'union-chain'`` only support
                ``'tail-batch'``.

        Returns:
            ``(score, score_cen, offset_norm, score_cen_plus, None, None)``,
            where the first four values come from the model function selected
            by ``self.model_name``.

        Raises:
            ValueError: if ``self.model_name``, ``self.geo``, ``qtype``,
                ``mode`` or ``rel_len`` is not supported.
        """
        model_func = {
            'BoxTransE': self.BoxTransE,
            'TransE': self.TransE,
        }
        # Validate configuration before doing any embedding lookups.
        if self.model_name not in model_func:
            raise ValueError('model %s not supported' % self.model_name)
        if self.geo not in ('box', 'vec'):
            raise ValueError('geo %s not supported' % self.geo)

        chain_types = ('1-chain', '2-chain', '3-chain')
        branch_types = ('2-inter', '3-inter', '2-union', '3-union')
        composite_types = ('chain-inter', 'inter-chain', 'union-chain')
        if qtype not in chain_types + branch_types + composite_types:
            raise ValueError('query type %s not supported' % qtype)
        if qtype in composite_types and mode != 'tail-batch':
            # Explicit check (not ``assert``) so it still runs under ``python -O``.
            raise ValueError('mode %s not supported for query type %s' % (mode, qtype))

        def lookup_entity(table, index):
            # Row lookup followed by one singleton axis at dim 1.
            return torch.index_select(table, dim=0, index=index).unsqueeze(1)

        def lookup_relation(table, index):
            # Row lookup followed by two singleton axes at dims 1 and 2.
            return torch.index_select(table, dim=0, index=index).unsqueeze(1).unsqueeze(1)

        # Separate the query index columns from the tail embeddings.
        if mode == 'single':
            query = sample
            tail = lookup_entity(self.entity_embedding, sample[:, -1])
        elif mode == 'tail-batch':
            query, tail_part = sample
            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
            tail = torch.index_select(
                self.entity_embedding, dim=0, index=tail_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)
        else:
            raise ValueError('mode %s not supported' % mode)

        # Column layout of ``query`` per query type: which columns hold anchor
        # entities, which hold relations, and the dim along which the
        # relation lookups are concatenated (anchors always stack on dim 0).
        if qtype in chain_types:
            if rel_len not in (1, 2, 3):
                raise ValueError('rel_len %s not supported for query type %s' % (rel_len, qtype))
            # [e, r1, (r2), (r3), ...]: one anchor, relations stacked on dim 1.
            head_cols = (0,)
            rel_cols = tuple(range(1, rel_len + 1))
            rel_dim = 1
        elif qtype in branch_types:
            if rel_len not in (2, 3):
                raise ValueError('rel_len %s not supported for query type %s' % (rel_len, qtype))
            # [e1, r1, e2, r2, (e3, r3), ...]: one (anchor, relation) per branch.
            head_cols = tuple(range(0, 2 * rel_len, 2))
            rel_cols = tuple(range(1, 2 * rel_len, 2))
            rel_dim = 0
            # One copy of the tails per branch, matching the dim-0 stacking.
            tail = torch.cat([tail] * rel_len, dim=0)
        elif qtype == 'chain-inter':
            # [e1, r11, r12, e2, r2]
            head_cols, rel_cols, rel_dim = (0, 3), (1, 2, 4), 0
        else:  # 'inter-chain' or 'union-chain'
            # [e1, r11, e2, r12, r2]
            head_cols, rel_cols, rel_dim = (0, 2), (1, 3, 4), 0

        head = torch.cat(
            [lookup_entity(self.entity_embedding, query[:, c]) for c in head_cols], dim=0)
        relation = torch.cat(
            [lookup_relation(self.relation_embedding, query[:, c]) for c in rel_cols], dim=rel_dim)

        # Offsets only exist for box embeddings; per-entity offsets additionally
        # require ``self.euo``.
        offset = None
        head_offset = None
        if self.geo == 'box':
            offset = torch.cat(
                [lookup_relation(self.offset_embedding, query[:, c]) for c in rel_cols], dim=rel_dim)
            if self.euo:
                head_offset = torch.cat(
                    [lookup_entity(self.entity_offset_embedding, query[:, c]) for c in head_cols], dim=0)

        # Each stacked intersection/union branch is a single hop, so the model
        # function is given a relation length of 1 for those query types.
        score_rel_len = 1 if qtype in branch_types else rel_len
        score, score_cen, offset_norm, score_cen_plus, _ = model_func[self.model_name](
            head, relation, tail, mode, offset, head_offset, score_rel_len, qtype)

        return score, score_cen, offset_norm, score_cen_plus, None, None