def BoxTransE()

in hype_kg/codes/model.py [0:0]


    def BoxTransE(self, head, relation, tail, mode, offset, head_offset, rel_len, qtype):
        """Score ``tail`` against the query box obtained by translating ``head``.

        Boxes are built TransE-style: the box center is ``head`` plus the sum of
        the relation embeddings, and every relation (plus the head itself via
        ``head_offset`` when ``self.euo`` is truthy) widens the box by
        ``0.5 * self.func(...)`` on each side. Multi-branch queries either
        intersect the per-branch boxes with ``self.center_sets`` /
        ``self.offset_sets``, or (unions) keep every branch stacked on dim 0 and
        take the best score over branches.

        Args:
            head: Anchor embedding(s). For multi-branch qtypes the branches are
                stacked along dim 0 and split with ``torch.chunk``.
            relation: Relation translations, indexed as ``relation[:, rel, :, :]``
                (so rank 4); presumably (batch, rel_len, 1, dim) -- TODO confirm.
            tail: Candidate answer embeddings compared against the query box.
            mode: Not used by this method; kept for interface compatibility.
            offset: Relation box sizes, indexed like ``relation`` and passed
                through ``self.func`` before use.
            head_offset: Head box size; only read when ``self.euo`` is truthy.
            rel_len: Number of relations composed along dim 1 in the generic path.
            qtype: Query type. ``'chain-inter'``, ``'inter-chain'`` and
                ``'union-chain'`` have dedicated branches. Any other string that
                contains ``'inter'`` / ``'union'`` is parsed as ``'<n>-inter'`` /
                ``'<n>-union'``. Anything else is scored as a plain chain.

        Returns:
            Tuple ``(score, score_center, offset_norm, score_center_plus, None)``:

            - ``score = gamma - ||distance of tail outside the box||_1``
            - ``score_center = gamma2 - ||center - tail||_1``
            - ``score_center_plus = gamma - ||outside||_1 - cen * ||clamp(tail, box) - center||_1``
            - ``offset_norm``: the mean L2 norm of ``offset`` over dim 2.

            For union qtypes each score is the max over the stacked branches.

        Raises:
            ValueError: if a ``'<n>-inter'`` / ``'<n>-union'`` qtype has fewer
                than 2 branches, or an intersection has a branch count other than
                2 or 3 (the set operators are only invoked with 2 or 3 boxes).
        """
        if qtype == 'chain-inter':
            # Branch 1 is a 2-hop chain (relations[0], relations[1]); branch 2 is
            # a 1-hop chain (relations[2]). The two boxes are then intersected.
            relations = torch.chunk(relation, 3, dim=0)
            offsets = torch.chunk(offset, 3, dim=0)
            if self.euo:
                head_offsets = torch.chunk(head_offset, 2, dim=0)

            heads = torch.chunk(head, 2, dim=0)

            query_center_1 = heads[0] + relations[0][:,0,:,:] + relations[1][:,0,:,:]
            query_center_2 = heads[1] + relations[2][:,0,:,:]
            if self.euo:
                query_min_1 = query_center_1 - 0.5 * self.func(head_offsets[0]) - 0.5 * self.func(offsets[0][:,0,:,:]) - 0.5 * self.func(offsets[1][:,0,:,:])
                query_min_2 = query_center_2 - 0.5 * self.func(head_offsets[1]) - 0.5 * self.func(offsets[2][:,0,:,:])
                query_max_1 = query_center_1 + 0.5 * self.func(head_offsets[0]) + 0.5 * self.func(offsets[0][:,0,:,:]) + 0.5 * self.func(offsets[1][:,0,:,:])
                query_max_2 = query_center_2 + 0.5 * self.func(head_offsets[1]) + 0.5 * self.func(offsets[2][:,0,:,:])
            else:
                query_min_1 = query_center_1 - 0.5 * self.func(offsets[0][:,0,:,:]) - 0.5 * self.func(offsets[1][:,0,:,:])
                query_min_2 = query_center_2 - 0.5 * self.func(offsets[2][:,0,:,:])
                query_max_1 = query_center_1 + 0.5 * self.func(offsets[0][:,0,:,:]) + 0.5 * self.func(offsets[1][:,0,:,:])
                query_max_2 = query_center_2 + 0.5 * self.func(offsets[2][:,0,:,:])
            query_center_1 = query_center_1.squeeze(1)
            query_center_2 = query_center_2.squeeze(1)
            # The set operators receive full box widths (max - min), not half-widths.
            offset_1 = (query_max_1 - query_min_1).squeeze(1)
            offset_2 = (query_max_2 - query_min_2).squeeze(1)
            new_query_center = self.center_sets(query_center_1, offset_1, query_center_2, offset_2)
            new_offset = self.offset_sets(query_center_1, offset_1, query_center_2, offset_2)
            new_query_min = (new_query_center - 0.5 * self.func(new_offset)).unsqueeze(1)
            new_query_max = (new_query_center + 0.5 * self.func(new_offset)).unsqueeze(1)
            score_offset = F.relu(new_query_min - tail) + F.relu(tail - new_query_max)
            score_center = new_query_center.unsqueeze(1) - tail
            score_center_plus = torch.min(new_query_max, torch.max(new_query_min, tail)) - new_query_center.unsqueeze(1)

        elif qtype == 'inter-chain':
            # Two 1-hop branches (relations[0], relations[1]) are intersected. The
            # intersection is then translated by relations[2] and widened by offsets[2].
            relations = torch.chunk(relation, 3, dim=0)
            offsets = torch.chunk(offset, 3, dim=0)
            if self.euo:
                head_offsets = torch.chunk(head_offset, 2, dim=0)

            heads = torch.chunk(head, 2, dim=0)

            query_center_1 = heads[0] + relations[0][:,0,:,:]
            query_center_2 = heads[1] + relations[1][:,0,:,:]
            if self.euo:
                query_min_1 = query_center_1 - 0.5 * self.func(head_offsets[0]) - 0.5 * self.func(offsets[0][:,0,:,:])
                query_min_2 = query_center_2 - 0.5 * self.func(head_offsets[1]) - 0.5 * self.func(offsets[1][:,0,:,:])
                query_max_1 = query_center_1 + 0.5 * self.func(head_offsets[0]) + 0.5 * self.func(offsets[0][:,0,:,:])
                query_max_2 = query_center_2 + 0.5 * self.func(head_offsets[1]) + 0.5 * self.func(offsets[1][:,0,:,:])
            else:
                query_min_1 = query_center_1 - 0.5 * self.func(offsets[0][:,0,:,:])
                query_min_2 = query_center_2 - 0.5 * self.func(offsets[1][:,0,:,:])
                query_max_1 = query_center_1 + 0.5 * self.func(offsets[0][:,0,:,:])
                query_max_2 = query_center_2 + 0.5 * self.func(offsets[1][:,0,:,:])
            query_center_1 = query_center_1.squeeze(1)
            query_center_2 = query_center_2.squeeze(1)
            offset_1 = (query_max_1 - query_min_1).squeeze(1)
            offset_2 = (query_max_2 - query_min_2).squeeze(1)
            conj_query_center = self.center_sets(query_center_1, offset_1, query_center_2, offset_2).unsqueeze(1)
            new_query_center = conj_query_center + relations[2][:,0,:,:]
            new_offset = self.offset_sets(query_center_1, offset_1, query_center_2, offset_2).unsqueeze(1)
            new_query_min = new_query_center - 0.5 * self.func(new_offset) - 0.5 * self.func(offsets[2][:,0,:,:])
            new_query_max = new_query_center + 0.5 * self.func(new_offset) + 0.5 * self.func(offsets[2][:,0,:,:])
            score_offset = F.relu(new_query_min - tail) + F.relu(tail - new_query_max)
            score_center = new_query_center - tail
            score_center_plus = torch.min(new_query_max, torch.max(new_query_min, tail)) - new_query_center

        elif qtype == 'union-chain':
            # The trailing hop (relations[2]) is applied to each branch separately.
            # Branches stay stacked on dim 0 so the union is a max over branch scores.
            relations = torch.chunk(relation, 3, dim=0)
            offsets = torch.chunk(offset, 3, dim=0)
            if self.euo:
                head_offsets = torch.chunk(head_offset, 2, dim=0)

            heads = torch.chunk(head, 2, dim=0)

            query_center_1 = heads[0] + relations[0][:,0,:,:] + relations[2][:,0,:,:]
            query_center_2 = heads[1] + relations[1][:,0,:,:] + relations[2][:,0,:,:]
            if self.euo:
                query_min_1 = query_center_1 - 0.5 * self.func(head_offsets[0]) - 0.5 * self.func(offsets[0][:,0,:,:]) - 0.5 * self.func(offsets[2][:,0,:,:])
                query_min_2 = query_center_2 - 0.5 * self.func(head_offsets[1]) - 0.5 * self.func(offsets[1][:,0,:,:]) - 0.5 * self.func(offsets[2][:,0,:,:])
                query_max_1 = query_center_1 + 0.5 * self.func(head_offsets[0]) + 0.5 * self.func(offsets[0][:,0,:,:]) + 0.5 * self.func(offsets[2][:,0,:,:])
                query_max_2 = query_center_2 + 0.5 * self.func(head_offsets[1]) + 0.5 * self.func(offsets[1][:,0,:,:]) + 0.5 * self.func(offsets[2][:,0,:,:])
            else:
                query_min_1 = query_center_1 - 0.5 * self.func(offsets[0][:,0,:,:]) - 0.5 * self.func(offsets[2][:,0,:,:])
                query_min_2 = query_center_2 - 0.5 * self.func(offsets[1][:,0,:,:]) - 0.5 * self.func(offsets[2][:,0,:,:])
                query_max_1 = query_center_1 + 0.5 * self.func(offsets[0][:,0,:,:]) + 0.5 * self.func(offsets[2][:,0,:,:])
                query_max_2 = query_center_2 + 0.5 * self.func(offsets[1][:,0,:,:]) + 0.5 * self.func(offsets[2][:,0,:,:])

            new_query_min = torch.stack([query_min_1, query_min_2], dim=0)
            new_query_max = torch.stack([query_max_1, query_max_2], dim=0)
            new_query_center = torch.stack([query_center_1, query_center_2], dim=0)
            score_offset = F.relu(new_query_min - tail) + F.relu(tail - new_query_max)
            score_center = new_query_center - tail
            score_center_plus = torch.min(new_query_max, torch.max(new_query_min, tail)) - new_query_center

        else:
            # Generic path: compose rel_len relations along dim 1 for every row.
            query_center = head
            for rel in range(rel_len):
                query_center = query_center + relation[:,rel,:,:]
            if self.euo:
                query_min = query_center - 0.5 * self.func(head_offset)
                query_max = query_center + 0.5 * self.func(head_offset)
            else:
                query_min = query_center
                query_max = query_center
            for rel in range(0, rel_len):
                query_min = query_min - 0.5 * self.func(offset[:,rel,:,:])
                query_max = query_max + 0.5 * self.func(offset[:,rel,:,:])

            if 'inter' not in qtype and 'union' not in qtype:
                # Plain chain query: one box per row.
                score_offset = F.relu(query_min - tail) + F.relu(tail - query_max)
                score_center = query_center - tail
                score_center_plus = torch.min(query_max, torch.max(query_min, tail)) - query_center
            else:
                # qtype is '<n>-inter' or '<n>-union'. The n branches are stacked
                # along dim 0 of every input.
                num_branches = int(qtype.split('-')[0])
                if num_branches < 2:
                    raise ValueError('qtype needs at least 2 branches: %s' % qtype)
                queries_min = torch.chunk(query_min, num_branches, dim=0)
                queries_max = torch.chunk(query_max, num_branches, dim=0)
                queries_center = torch.chunk(query_center, num_branches, dim=0)
                tails = torch.chunk(tail, num_branches, dim=0)
                offsets = query_max - query_min
                offsets = torch.chunk(offsets, num_branches, dim=0)
                if 'inter' in qtype:
                    if num_branches == 2:
                        new_query_center = self.center_sets(queries_center[0].squeeze(1), offsets[0].squeeze(1),
                                                        queries_center[1].squeeze(1), offsets[1].squeeze(1))
                        new_offset = self.offset_sets(queries_center[0].squeeze(1), offsets[0].squeeze(1),
                                                        queries_center[1].squeeze(1), offsets[1].squeeze(1))

                    elif num_branches == 3:
                        new_query_center = self.center_sets(queries_center[0].squeeze(1), offsets[0].squeeze(1),
                                                        queries_center[1].squeeze(1), offsets[1].squeeze(1),
                                                        queries_center[2].squeeze(1), offsets[2].squeeze(1))
                        new_offset = self.offset_sets(queries_center[0].squeeze(1), offsets[0].squeeze(1),
                                                        queries_center[1].squeeze(1), offsets[1].squeeze(1),
                                                        queries_center[2].squeeze(1), offsets[2].squeeze(1))
                    else:
                        # Previously this fell through to an UnboundLocalError on
                        # new_query_center below; fail with a clear message instead.
                        raise ValueError('intersection supports 2 or 3 branches, got %d: %s' % (num_branches, qtype))
                    new_query_min = (new_query_center - 0.5*self.func(new_offset)).unsqueeze(1)
                    new_query_max = (new_query_center + 0.5*self.func(new_offset)).unsqueeze(1)
                    # Only the first tail chunk is scored; the intersection yields one box per query.
                    score_offset = F.relu(new_query_min - tails[0]) + F.relu(tails[0] - new_query_max)
                    score_center = new_query_center.unsqueeze(1) - tails[0]
                    score_center_plus = torch.min(new_query_max, torch.max(new_query_min, tails[0])) - new_query_center.unsqueeze(1)
                elif 'union' in qtype:
                    # Keep every branch box (stacked on a new dim 0); scores are max-reduced below.
                    new_query_min = torch.stack(queries_min, dim=0)
                    new_query_max = torch.stack(queries_max, dim=0)
                    new_query_center = torch.stack(queries_center, dim=0)
                    score_offset = F.relu(new_query_min - tails[0]) + F.relu(tails[0] - new_query_max)
                    score_center = new_query_center - tails[0]
                    score_center_plus = torch.min(new_query_max, torch.max(new_query_min, tails[0])) - new_query_center
                else:
                    # Defensive: unreachable, since the enclosing branch requires 'inter' or 'union'.
                    raise ValueError('qtype not exists: %s' % qtype)
        score = self.gamma.item() - torch.norm(score_offset, p=1, dim=-1)
        score_center = self.gamma2.item() - torch.norm(score_center, p=1, dim=-1)
        # Outside distance plus a cen-weighted inside distance (clamped tail to center).
        score_center_plus = self.gamma.item() - torch.norm(score_offset, p=1, dim=-1) - self.cen * torch.norm(score_center_plus, p=1, dim=-1)
        if 'union' in qtype:
            # Union: a tail matches if it matches any branch, so take the best branch score.
            score = torch.max(score, dim=0)[0]
            score_center = torch.max(score_center, dim=0)[0]
            score_center_plus = torch.max(score_center_plus, dim=0)[0]

        return score, score_center, torch.mean(torch.norm(offset, p=2, dim=2).squeeze(1)), score_center_plus, None