# forward()
#
# from hype_kg/codes/model.py [0:0]


    def forward(self, embeds1, embeds1_o, embeds2, embeds2_o, embeds3=None, embeds3_o=None):
        """Fuse two or three embedding views into one center embedding by
        attention-weighted averaging.

        Args:
            embeds1, embeds2, embeds3: view embeddings, assumed shape
                (batch, dim) — TODO confirm against Attention_module.
                embeds3 is optional; ``None`` (new default) or an empty
                sequence (the legacy ``[]`` default) means two-view fusion.
            embeds1_o, embeds2_o, embeds3_o: companion inputs forwarded to
                ``self.Attention_module`` alongside the matching embeds.

        Returns:
            Tensor of shape (batch, dim): the attention-weighted combination
            of the provided views.

        Raises:
            ValueError: if ``self.att_type`` is neither ``'whole'`` nor
                ``'ele'`` (previously this fell through to an
                ``UnboundLocalError`` on ``return center``).
        """
        # Accept both None (new default) and the legacy empty-list sentinel.
        use_third = embeds3 is not None and len(embeds3) > 0

        def _score(embeds, embeds_o):
            # Regularized, temperature-scaled attention score; the 1e-4
            # guard keeps the division finite when att_tem is zero.
            # Assumes Attention_module returns a (batch, 1) column — this
            # matches the original code's combined[:, i] indexing.
            return (self.Attention_module(embeds, embeds_o) + self.att_reg) / (self.att_tem + 1e-4)

        views = [embeds1, embeds2]
        scores = [_score(embeds1, embeds1_o), _score(embeds2, embeds2_o)]
        if use_third:
            views.append(embeds3)
            scores.append(_score(embeds3, embeds3_o))

        if self.att_type == 'whole':
            # One scalar weight per sample: softmax across views along dim=1.
            combined = F.softmax(torch.cat(scores, dim=1), dim=1)
            center = sum(v * combined[:, i].view(v.size(0), 1) for i, v in enumerate(views))
        elif self.att_type == 'ele':
            # Element-wise weights: softmax across the stacked view axis.
            combined = F.softmax(torch.stack(scores), dim=0)
            center = sum(v * combined[i] for i, v in enumerate(views))
        else:
            raise ValueError(f"Unknown att_type: {self.att_type!r} (expected 'whole' or 'ele')")

        return center