in hype_kg/codes/model.py [0:0]
def __init__(self, manifold, mode_dims, expand_dims, center_use_offset, att_type, bn, nat, name="Real"):
    """Build the attention weight matrices as ManifoldParameters on ``manifold``.

    Args:
        manifold: Manifold attached to every weight matrix created here.
        mode_dims: Column count of every weight matrix (except ``atten_mats2``
            in ``'whole'`` mode, which has one column) and the feature size
            of the BatchNorm layers.
        expand_dims: Row count of ``atten_mats1``; doubled when
            ``center_use_offset`` is truthy.
        center_use_offset: If truthy, ``atten_mats1`` gets ``2 * expand_dims``
            rows (presumably for a concatenated center/offset input -- TODO
            confirm against ``forward``).
        att_type: ``'whole'`` makes ``atten_mats2`` of shape ``(mode_dims, 1)``;
            ``'ele'`` makes it ``(mode_dims, mode_dims)``.
        bn: Any value other than ``'no'`` creates ``bn1``, ``bn1_1`` and
            ``bn1_2`` (``BatchNorm1d(mode_dims)`` each).
        nat: ``>= 2`` adds ``atten_mats1_1``; ``>= 3`` also adds ``atten_mats1_2``
            (both ``(mode_dims, mode_dims)``).
        name: Suffix of the extra name each weight is registered under.

    Raises:
        ValueError: If ``att_type`` is neither ``'whole'`` nor ``'ele'``.
    """
    super(Attention, self).__init__()
    # Fail fast with a clear message; previously an unknown att_type only
    # surfaced later as an AttributeError on the never-created atten_mats2.
    if att_type not in ('whole', 'ele'):
        raise ValueError("att_type must be 'whole' or 'ele', got %r" % (att_type,))
    self.manifold = manifold
    self.center_use_offset = center_use_offset
    self.bn = bn
    self.nat = nat

    def _add_weight(attr, rows, cols):
        # Create one (rows, cols) weight, expose it as ``self.<attr>``,
        # Xavier-initialise it in place, and register it under a second name.
        # NOTE(review): values are drawn in ambient space and are not
        # projected onto ``manifold`` here -- confirm that is intended.
        weight = ManifoldParameter(torch.FloatTensor(rows, cols), manifold=self.manifold)
        setattr(self, attr, weight)
        nn.init.xavier_uniform_(weight)
        # The same tensor is registered again under "<attr>_<name>", so the
        # state_dict carries both keys; kept so existing checkpoint keys match.
        self.register_parameter("%s_%s" % (attr, name), weight)

    in_rows = expand_dims * 2 if center_use_offset else expand_dims
    _add_weight("atten_mats1", in_rows, mode_dims)
    if self.nat >= 2:
        _add_weight("atten_mats1_1", mode_dims, mode_dims)
    if self.nat >= 3:
        _add_weight("atten_mats1_2", mode_dims, mode_dims)
    if bn != 'no':
        # All three norms are created whenever bn is enabled, regardless of nat.
        self.bn1 = nn.BatchNorm1d(mode_dims)
        self.bn1_1 = nn.BatchNorm1d(mode_dims)
        self.bn1_2 = nn.BatchNorm1d(mode_dims)
    _add_weight("atten_mats2", mode_dims, 1 if att_type == 'whole' else mode_dims)