in hype_kg/codes/model.py [0:0]
def forward(self, embeds1, embeds1_o, embeds2, embeds2_o, embeds3=None, embeds3_o=None):
    """Combine the offset embeddings of two (optionally three) branches into one.

    Each branch is projected by ``self.pre_mats`` and passed through ReLU.
    The projected branches are stacked and reduced with ``self.agg_func``
    along the branch axis. The result is then projected by ``self.post_mats``.

    Args:
        embeds1, embeds2, embeds3: Per-branch "center" tensors. They are only
            read when ``self.offset_use_center`` is true, in which case they
            are concatenated in front of the matching offset along dim 1.
        embeds1_o, embeds2_o, embeds3_o: Per-branch offset tensors. They must
            be 2-D, because they go through ``Tensor.mm``. ``embeds3_o`` is
            optional: ``None`` or an empty sequence/tensor means only the
            first two branches are combined.

    Attributes read from ``self``:
        offset_use_center: Whether the center embeddings are prepended to the
            offsets before projection.
        pre_mats: Right-hand matrix for the per-branch projection. Its row
            count must match the (possibly concatenated) feature width.
        agg_func: Reduction called as ``agg_func(stacked, dim=0)``. If it
            returns a tuple (e.g. ``torch.max`` -> ``(values, indices)``),
            only the first element is used.
        post_mats: Right-hand matrix applied to the aggregated result.

    Returns:
        ``agg_func(...)`` result multiplied by ``self.post_mats``.
    """
    branches = [(embeds1, embeds1_o), (embeds2, embeds2_o)]
    # The third branch is optional; both None (new default) and an empty
    # list/tensor (old default, still accepted from callers) disable it.
    if embeds3_o is not None and len(embeds3_o) > 0:
        branches.append((embeds3, embeds3_o))

    projected = []
    for center, offset in branches:
        if self.offset_use_center:
            feats = torch.cat([center, offset], dim=1)
        else:
            feats = offset
        projected.append(F.relu(feats.mm(self.pre_mats)))

    combined = self.agg_func(torch.stack(projected), dim=0)
    # torch.max / torch.min return a named (values, indices) structure that is
    # a tuple *subclass*, so an exact `type(...) == tuple` check misses it.
    if isinstance(combined, tuple):
        combined = combined[0]
    return combined.mm(self.post_mats)