in hype_kg/codes/model.py [0:0]
def forward(self, embeds1, embeds1_o, embeds2, embeds2_o, embeds3=None, embeds3_o=None):
    """Aggregate two or three input embeddings into a single output tensor.

    Each input is optionally concatenated with its offset along dim 1 (when
    ``self.center_use_offset`` is true). It is then multiplied by
    ``self.pre_mats`` and passed through a ReLU, with ``self.bn1/2/3``
    applied before or after the ReLU depending on ``self.bn``. The
    projected inputs are stacked and reduced with ``self.agg_func(..., dim=0)``.
    The result is multiplied by ``self.post_mats``.

    Args:
        embeds1, embeds2: 2-D tensors (``Tensor.mm`` requires 2-D operands);
            presumably (batch, dim) -- confirm against callers.
        embeds1_o, embeds2_o: tensors concatenated with the matching embedding
            along dim 1 when ``self.center_use_offset`` is true; ignored
            otherwise.
        embeds3: optional third embedding. ``None`` or an empty sequence
            (e.g. the legacy ``[]``) means only two inputs are aggregated.
        embeds3_o: companion of ``embeds3``; required when ``embeds3`` is
            supplied and ``self.center_use_offset`` is true.

    Returns:
        The reduced tensor multiplied by ``self.post_mats``.

    Raises:
        ValueError: if ``self.bn`` is not 'no', 'before' or 'after', or if
            ``embeds3`` is supplied under ``center_use_offset`` without
            ``embeds3_o``.
    """
    def _present(x):
        # None and zero-length inputs both mean "not supplied", mirroring the
        # original ``len(embeds3) > 0`` test.
        return x is not None and len(x) > 0

    if self.bn not in ('no', 'before', 'after'):
        # Previously an unknown mode silently skipped the pre-projection.
        raise ValueError(
            "bn must be 'no', 'before' or 'after', got %r" % (self.bn,))

    def _project(x, bn_name):
        # Linear map + ReLU, with the named batch-norm layer applied before or
        # after the ReLU. The layer is fetched lazily so it need not exist
        # when bn == 'no'.
        h = x.mm(self.pre_mats)
        if self.bn == 'no':
            return F.relu(h)
        bn_layer = getattr(self, bn_name)
        if self.bn == 'before':
            return F.relu(bn_layer(h))
        return bn_layer(F.relu(h))

    has_third = _present(embeds3)
    if self.center_use_offset and has_third and not _present(embeds3_o):
        raise ValueError(
            "embeds3_o is required when embeds3 is given and center_use_offset is set")

    inputs = [(embeds1, embeds1_o, 'bn1'), (embeds2, embeds2_o, 'bn2')]
    if has_third:
        inputs.append((embeds3, embeds3_o, 'bn3'))

    projected = []
    for emb, offset, bn_name in inputs:
        x = torch.cat([emb, offset], dim=1) if self.center_use_offset else emb
        projected.append(_project(x, bn_name))

    combined = self.agg_func(torch.stack(projected), dim=0)
    # torch.min / torch.max with ``dim`` return a (values, indices) structseq,
    # which is a tuple subclass -- ``type(...) == tuple`` would miss it.
    if isinstance(combined, tuple):
        combined = combined[0]
    return combined.mm(self.post_mats)