in hype_kg/codes/model.py [0:0]
def forward(self, embeds1, embeds2, embeds3=None, name='real'):
    """Aggregate two or three embedding batches into one embedding batch.

    Each operand is projected with a pre-aggregation matrix and passed
    through ReLU. The projected operands are stacked along a new leading
    axis, reduced with ``self.agg_func(..., dim=0)``, and finally projected
    with a post-aggregation matrix.

    Args:
        embeds1: First operand; must be a 2-D tensor (``Tensor.mm`` requires it).
        embeds2: Second operand; same requirements as ``embeds1``.
        embeds3: Optional third operand. ``None`` (the default) or any empty
            sequence means only ``embeds1`` and ``embeds2`` are aggregated.
        name: Selects the projection pair. ``'real'`` uses
            ``self.pre_mats`` / ``self.post_mats``; ``'img'`` uses
            ``self.pre_mats_im`` / ``self.post_mats_im``.

    Returns:
        The aggregated tensor after the post-aggregation projection.

    Raises:
        ValueError: If ``name`` is neither ``'real'`` nor ``'img'``.
    """
    if name == 'real':
        pre_mats, post_mats = self.pre_mats, self.post_mats
    elif name == 'img':
        pre_mats, post_mats = self.pre_mats_im, self.post_mats_im
    else:
        # Without this guard an unknown name fell through and surfaced as an
        # opaque UnboundLocalError on the final return.
        raise ValueError("name must be 'real' or 'img', got {!r}".format(name))

    operands = [embeds1, embeds2]
    # Accept both None and an explicitly passed empty sequence (the old
    # default) as "no third operand".
    if embeds3 is not None and len(embeds3) > 0:
        operands.append(embeds3)

    combined = torch.stack([F.relu(e.mm(pre_mats)) for e in operands])
    combined = self.agg_func(combined, dim=0)
    # Reductions such as torch.max / torch.min return a (values, indices)
    # named tuple whose type is a tuple *subclass*; `type(x) == tuple` misses
    # it, while isinstance handles both plain and named tuples.
    if isinstance(combined, tuple):
        combined = combined[0]
    return combined.mm(post_mats)