in hype_kg/codes/model.py [0:0]
def forward(self, center_embed, offset_embed=None):
    """Run the attention MLP over the center (and optionally offset) embedding.

    The input goes through up to three hidden layers. Hidden layer ``k``
    (1-based) is applied only when ``self.nat >= k``, so ``nat > 3`` behaves
    like ``nat == 3``. Each layer is a matrix multiply followed by ReLU,
    with optional normalisation placed according to ``self.bn``:

    * ``'no'``     -- ``relu(x.mm(W))``
    * ``'before'`` -- ``relu(bn(x.mm(W)))``
    * ``'after'``  -- ``bn(relu(x.mm(W)))``

    Layer weights are ``atten_mats1``, ``atten_mats1_1`` and
    ``atten_mats1_2``. The matching normalisation modules are ``bn1``,
    ``bn1_1`` and ``bn1_2``; these are only accessed when
    ``self.bn != 'no'``. The hidden result is finally projected by
    ``self.atten_mats2`` with no activation.

    Args:
        center_embed: 2-D tensor (``Tensor.mm`` requires 2-D operands);
            presumably ``(batch, dim)`` -- confirm against callers.
        offset_embed: 2-D tensor concatenated onto ``center_embed`` along
            dim 1 when ``self.center_use_offset`` is true; ignored otherwise.

    Returns:
        The 2-D tensor ``hidden.mm(self.atten_mats2)``.

    Raises:
        ValueError: if ``offset_embed`` is missing while
            ``self.center_use_offset`` is true, if ``self.bn`` is not one of
            ``'no'``, ``'before'``, ``'after'``, or if ``self.nat < 1``.
    """
    if self.center_use_offset:
        if offset_embed is None:
            raise ValueError(
                "offset_embed is required when center_use_offset is True")
        hidden = torch.cat([center_embed, offset_embed], dim=1)
    else:
        hidden = center_embed

    # Fail loudly on bad configuration. Previously both cases left the
    # hidden output unbound and surfaced as an opaque UnboundLocalError.
    if self.bn not in ('no', 'before', 'after'):
        raise ValueError(
            "bn must be 'no', 'before' or 'after', got {!r}".format(self.bn))
    if self.nat < 1:
        raise ValueError("nat must be >= 1, got {!r}".format(self.nat))

    # (weight attribute, batch-norm attribute) for each hidden layer.
    # Attributes are looked up lazily, so layers beyond `nat` (and bn modules
    # when bn == 'no') need not exist on the instance.
    layer_attrs = (
        ('atten_mats1', 'bn1'),
        ('atten_mats1_1', 'bn1_1'),
        ('atten_mats1_2', 'bn1_2'),
    )
    for depth, (weight_name, bn_name) in enumerate(layer_attrs, start=1):
        if self.nat < depth:
            break
        pre_act = hidden.mm(getattr(self, weight_name))
        if self.bn == 'no':
            hidden = F.relu(pre_act)
        elif self.bn == 'before':
            hidden = F.relu(getattr(self, bn_name)(pre_act))
        else:  # 'after' (validated above)
            hidden = getattr(self, bn_name)(F.relu(pre_act))

    return hidden.mm(self.atten_mats2)