def __init__()

in hype_kg/codes/model.py [0:0]


    def __init__(self, manifold, mode_dims, expand_dims, center_use_offset, att_type, bn, nat, name="Real"):
        """Create the attention weight matrices and optional batch-norm layers.

        Every weight matrix is a ``ManifoldParameter`` bound to ``manifold``,
        filled with Xavier-uniform values, and registered a second time under
        ``"<attribute>_<name>"`` in addition to its plain attribute name.

        Args:
            manifold: Manifold object passed to each ``ManifoldParameter``.
            mode_dims: Number of columns of the first matrix; also the size
                of the later square matrices and of the batch-norm layers.
            expand_dims: Number of rows of the first matrix (doubled when
                ``center_use_offset`` is truthy).
            center_use_offset: When truthy, the first matrix gets
                ``expand_dims * 2`` rows -- presumably because the input
                concatenates a center and an offset part; TODO confirm
                against the forward pass.
            att_type: ``'whole'`` builds a ``(mode_dims, 1)`` final matrix,
                ``'ele'`` builds a ``(mode_dims, mode_dims)`` one.
            bn: Three ``BatchNorm1d(mode_dims)`` layers are created unless
                this equals ``'no'``.
            nat: Depth selector: ``>= 2`` adds ``atten_mats1_1`` and
                ``>= 3`` additionally adds ``atten_mats1_2``.
            name: Suffix appended to the extra registered parameter names.

        Raises:
            ValueError: If ``att_type`` is neither ``'whole'`` nor ``'ele'``.
        """
        super(Attention, self).__init__()

        self.manifold = manifold
        self.center_use_offset = center_use_offset
        self.bn = bn
        self.nat = nat

        def _add_matrix(attr, rows, cols):
            # Allocate, initialise and register one weight matrix.
            weight = ManifoldParameter(torch.FloatTensor(rows, cols), manifold=self.manifold)
            # In-place initialiser; the un-suffixed ``xavier_uniform`` alias
            # is deprecated.
            nn.init.xavier_uniform_(weight)
            setattr(self, attr, weight)
            # The extra registration keeps the historical "<attr>_<name>"
            # key next to the plain attribute key, so existing checkpoints
            # continue to load.
            self.register_parameter("%s_%s" % (attr, name), weight)

        # Creation order is kept stable so seeded runs draw the same values.
        first_rows = expand_dims * 2 if center_use_offset else expand_dims
        _add_matrix("atten_mats1", first_rows, mode_dims)
        if self.nat >= 2:
            _add_matrix("atten_mats1_1", mode_dims, mode_dims)
        if self.nat >= 3:
            _add_matrix("atten_mats1_2", mode_dims, mode_dims)

        if bn != 'no':
            # All three layers are created regardless of ``nat``.
            self.bn1 = nn.BatchNorm1d(mode_dims)
            self.bn1_1 = nn.BatchNorm1d(mode_dims)
            self.bn1_2 = nn.BatchNorm1d(mode_dims)

        if att_type == 'whole':
            _add_matrix("atten_mats2", mode_dims, 1)
        elif att_type == 'ele':
            _add_matrix("atten_mats2", mode_dims, mode_dims)
        else:
            # Previously this fell through to an opaque AttributeError on
            # the missing ``atten_mats2``; fail with a clear message instead.
            raise ValueError(
                "att_type must be 'whole' or 'ele', got %r" % (att_type,)
            )