in hype_kg/codes/model.py [0:0]
def forward(self, embeds1, embeds1_o, embeds2, embeds2_o, embeds3=None, embeds3_o=None):
    """Fuse two or three embeddings into one attention-weighted center.

    Each input pair ``(embeds_k, embeds_k_o)`` is scored by
    ``self.Attention_module``. The score is shifted by ``self.att_reg`` and
    divided by ``self.att_tem + 1e-4``; the epsilon keeps the divisor away
    from zero. The scores are softmax-normalized across the inputs and used
    to take a weighted sum of the ``embeds_k`` tensors.

    Args:
        embeds1, embeds1_o: First input and the companion tensor passed
            alongside it to ``Attention_module``.
        embeds2, embeds2_o: Second input and its companion tensor.
        embeds3, embeds3_o: Optional third input and its companion.
            If ``embeds3`` is omitted, ``None`` or empty (``len == 0``),
            only the first two inputs are fused. An explicit ``[]`` is
            still accepted for backward compatibility.

    ``self.att_type`` selects the weighting scheme:
        * ``'whole'``: one scalar weight per row. The scores are concatenated
          along dim 1 and softmaxed per row. This presumably expects
          ``Attention_module`` to return shape ``(batch, 1)``; TODO confirm.
        * ``'ele'``: element-wise weights. The scores are stacked and
          softmaxed across the inputs (dim 0). This presumably expects scores
          broadcastable with the embeddings; TODO confirm.

    Returns:
        The attention-weighted sum of the input embeddings.

    Raises:
        ValueError: If ``self.att_type`` is not ``'whole'`` or ``'ele'``.
            Previously ``center`` was left unbound and an
            ``UnboundLocalError`` surfaced instead.
    """
    # ``None`` replaces the former mutable ``[]`` defaults. An explicitly
    # passed empty sequence still selects the two-input path.
    embeds = [embeds1, embeds2]
    embeds_o = [embeds1_o, embeds2_o]
    if embeds3 is not None and len(embeds3) > 0:
        embeds.append(embeds3)
        embeds_o.append(embeds3_o)

    temperature = self.att_tem + 1e-4
    # Score inputs in order (1, 2[, 3]), matching the original call order.
    scores = [
        (self.Attention_module(e, e_o) + self.att_reg) / temperature
        for e, e_o in zip(embeds, embeds_o)
    ]

    if self.att_type == 'whole':
        # Concatenating the scores gives one column per input. The per-row
        # softmax yields one scalar weight per (row, input).
        weights = F.softmax(torch.cat(scores, dim=1), dim=1)
        center = sum(
            e * weights[:, i].view(e.size(0), 1) for i, e in enumerate(embeds)
        )
    elif self.att_type == 'ele':
        # Stack the scores on a new leading axis and normalize across inputs,
        # so that each element position gets its own weights.
        weights = F.softmax(torch.stack(scores), dim=0)
        center = sum(e * w for e, w in zip(embeds, weights))
    else:
        raise ValueError(
            f"Unsupported att_type {self.att_type!r}; expected 'whole' or 'ele'"
        )
    return center