def __init__()

in hype_kg/codes/model.py [0:0]


    def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma,
                 writer=None, geo=None,
                 cen=None, offset_deepsets=None,
                 center_deepsets=None, offset_use_center=None, center_use_offset=None,
                 att_reg=0., off_reg=0., att_tem=1., euo=False,
                 gamma2=0, bn='no', nat=1, activation='relu', manifold='poincare',
                 curvature=1.0, trainable_curvature=False,
                 use_semantics=False):
        """Create the manifold, the embedding tables and the set aggregators.

        Args:
            model_name: Must be 'TransE' or 'BoxTransE'.
            nentity: Number of rows of the entity embedding table(s).
            nrelation: Number of rows of the relation (and offset) tables.
            hidden_dim: Width of every embedding table; also used as the
                input/output size handed to the aggregation modules.
            gamma: Stored as the frozen parameter ``self.gamma``; together
                with ``self.epsilon`` it fixes the uniform init range
                ``(gamma + epsilon) / hidden_dim``.
            writer: Stored on the instance; not used in this method.
            geo: 'vec' builds ``self.deepsets``; 'box' builds the offset
                table(s), ``self.center_sets`` and ``self.offset_sets``.
                Any other value builds no aggregation module.
            cen: Stored on the instance; not used in this method.
            offset_deepsets: 'vanilla', 'inductive' or 'min'. Only read when
                ``geo == 'box'``.
            center_deepsets: 'vanilla', 'attention', 'eleattention' or
                'mean'. Only read when ``geo`` is 'vec' or 'box'.
            offset_use_center: Forwarded to the offset aggregator.
            center_use_offset: Forwarded to the center aggregator when
                ``geo == 'box'`` (the 'vec' variant always passes False).
            att_reg: Forwarded to ``AttentionSet``.
            off_reg: Forwarded to ``InductiveOffsetSet``.
            att_tem: Forwarded to ``AttentionSet``.
            euo: When truthy and ``geo == 'box'``, a per-entity offset table
                ``self.entity_offset_embedding`` is created as well.
            gamma2: Stored as the frozen parameter ``self.gamma2``; the
                value 0 means "reuse ``gamma``".
            bn: Forwarded to ``CenterSet`` / ``AttentionSet``.
            nat: Forwarded to ``CenterSet`` / ``AttentionSet``.
            activation: 'none', 'relu' or 'softplus'; selects ``self.func``.
            manifold: 'poincare' or 'lorentz'; anything else selects the
                Euclidean manifold.
            curvature: Initial value of the curvature parameter ``self.c``.
            trainable_curvature: Whether ``self.c`` requires gradients.
            use_semantics: When truthy, entity embeddings are loaded from
                ``get_vectors.pkl`` (relative to the working directory)
                instead of being drawn uniformly at random.

        Raises:
            ValueError: If ``model_name``, ``center_deepsets`` or
                ``offset_deepsets`` is not one of the supported values, or
                if the loaded semantic vectors do not have ``hidden_dim``
                components.
        """
        super(Query2Manifold, self).__init__()

        # Reject unsupported models before allocating tensors or reading files.
        if model_name not in ['TransE', 'BoxTransE']:
            raise ValueError('model %s not supported' % model_name)

        self.model_name = model_name
        self.nentity = nentity
        self.nrelation = nrelation
        self.hidden_dim = hidden_dim
        self.epsilon = 2.0
        self.writer = writer
        self.geo = geo
        self.cen = cen
        self.offset_deepsets = offset_deepsets
        self.center_deepsets = center_deepsets
        self.offset_use_center = offset_use_center
        self.center_use_offset = center_use_offset
        self.att_reg = att_reg
        self.off_reg = off_reg
        self.att_tem = att_tem
        self.euo = euo
        self.his_step = 0
        self.bn = bn
        self.nat = nat

        # NOTE(review): any other activation value leaves self.func unset, so
        # a later read of self.func would raise AttributeError. Kept as-is in
        # case callers pass other values on code paths that never use it.
        if activation == 'none':
            self.func = Identity
        elif activation == 'relu':
            self.func = F.relu
        elif activation == 'softplus':
            self.func = F.softplus

        # The curvature is always registered as a parameter; only its
        # trainability depends on the flag.
        self.c = nn.Parameter(torch.FloatTensor([curvature]),
                              requires_grad=bool(trainable_curvature))

        if manifold == 'poincare':
            self.manifold = poincare.PoincareBall(self.c)
        elif manifold == 'lorentz':
            self.manifold = lorentz.Lorentz(self.c)
        else:
            self.manifold = euclidean.Euclidean()

        def frozen_scalar(value):
            # One-element, non-trainable parameter attached to the manifold.
            return ManifoldParameter(
                torch.Tensor([value]),
                requires_grad=False,
                manifold=self.manifold
            )

        self.gamma = frozen_scalar(gamma)

        if gamma2 == 0:
            gamma2 = gamma
        self.gamma2 = frozen_scalar(gamma2)

        self.embedding_range = frozen_scalar(
            (self.gamma.item() + self.epsilon) / hidden_dim
        )
        bound = self.embedding_range.item()

        self.entity_dim = hidden_dim
        self.relation_dim = hidden_dim

        def uniform_table(rows, cols, low):
            # Trainable (rows x cols) table drawn from U(low, bound).
            table = ManifoldParameter(torch.zeros(rows, cols), manifold=self.manifold)
            nn.init.uniform_(tensor=table, a=low, b=bound)
            return table

        def make_center_set(use_offset):
            # Single place that maps `center_deepsets` to its aggregator, so
            # the 'vec' and 'box' geometries cannot drift apart.
            dim = self.relation_dim
            kind = self.center_deepsets
            if kind == 'vanilla':
                return CenterSet(self.manifold, dim, dim, use_offset,
                                 agg_func=torch.mean, bn=bn, nat=nat)
            if kind == 'attention':
                return AttentionSet(self.manifold, dim, dim, use_offset,
                                    att_reg=self.att_reg, att_tem=self.att_tem,
                                    bn=bn, nat=nat)
            if kind == 'eleattention':
                return AttentionSet(self.manifold, dim, dim, use_offset,
                                    att_reg=self.att_reg, att_type='ele',
                                    att_tem=self.att_tem, bn=bn, nat=nat)
            if kind == 'mean':
                return MeanSet(self.manifold)
            raise ValueError('center_deepsets %s not supported' % kind)

        if use_semantics:
            # Key: Entity Id, Value: Semantic Vector
            # SECURITY: pickle can execute arbitrary code while loading; this
            # file must come from a trusted source.
            with open("get_vectors.pkl", "rb") as vector_file:
                semantic_vectors = pickle.load(vector_file)
            # Semantic Vector dimensions should match the vector dimensions of box centers
            vector_dim = len(next(iter(semantic_vectors.values())))
            if vector_dim != hidden_dim:
                raise ValueError(
                    'semantic vectors have dimension %d, expected hidden_dim %d'
                    % (vector_dim, hidden_dim)
                )
            # The loaded vectors are kept as the initial embeddings. Previously
            # an unconditional uniform re-initialisation overwrote them, which
            # made this flag a no-op.
            # NOTE(review): the vectors are used exactly as stored; nothing here
            # projects them onto the chosen manifold - confirm they are valid
            # points for it.
            self.entity_embedding = ManifoldParameter(
                torch.stack([torch.Tensor(semantic_vectors[entity_id])
                             for entity_id in range(nentity)]),
                manifold=self.manifold
            )
        else:
            self.entity_embedding = uniform_table(nentity, self.entity_dim, -bound)

        self.relation_embedding = uniform_table(nrelation, self.relation_dim, -bound)

        if self.geo == 'vec':
            self.deepsets = make_center_set(False)

        if self.geo == 'box':
            # Offsets are drawn from [0, bound], unlike the centers above.
            self.offset_embedding = uniform_table(nrelation, self.entity_dim, 0.)
            if self.euo:
                self.entity_offset_embedding = uniform_table(nentity, self.entity_dim, 0.)

            self.center_sets = make_center_set(self.center_use_offset)

            if self.offset_deepsets == 'vanilla':
                self.offset_sets = OffsetSet(self.manifold, self.relation_dim, self.relation_dim,
                                             self.offset_use_center, agg_func=torch.mean)
            elif self.offset_deepsets == 'inductive':
                self.offset_sets = InductiveOffsetSet(self.manifold, self.relation_dim, self.relation_dim,
                                                      self.offset_use_center, self.off_reg,
                                                      agg_func=torch.mean)
            elif self.offset_deepsets == 'min':
                self.offset_sets = MinSet(self.manifold)
            else:
                raise ValueError('offset_deepsets %s not supported' % self.offset_deepsets)