def forward()

in hugegraph-ml/src/hugegraph_ml/models/gatne.py
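
`forward` produces one embedding per edge type for every destination node in a sampled DGL block: it sums type-specific neighbor embeddings per edge type, fuses them with a small attention over edge types, projects the fused vector into the base embedding space, and adds it to the shared node embedding before L2 normalization. The body relies on the module's usual imports (`dgl`, `dgl.function as fn`, `torch`, and `torch.nn.functional as F`).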


    def forward(self, block):
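        """Compute attention-fused, per-edge-type embeddings for the block's
        destination nodes.

        Args:
            block: a DGL block (bipartite graph from a neighbor sampler)
                whose edge types match ``self.edge_types``.

        Returns:
            Tensor of shape [batch_size, edge_type_count, embedding_size].
        """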
        input_nodes = block.srcdata[dgl.NID]      # global IDs of source nodes
        output_nodes = block.dstdata[dgl.NID]     # global IDs of destination nodes
        batch_size = block.number_of_dst_nodes()
        node_embed = self.node_embeddings         # shared base embedding table
        node_type_embed = []

        with block.local_scope():
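            # For each edge type, copy the type-specific embeddings of the source
            # nodes onto the block and sum them over that type's incoming edges,
            # giving one aggregated u-dimensional vector per destination node.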
            for i, edge_type in enumerate(self.edge_types):
                block.srcdata[edge_type] = self.node_type_embeddings[input_nodes, i]
                block.dstdata[edge_type] = self.node_type_embeddings[output_nodes, i]
                block.update_all(
                    fn.copy_u(edge_type, "m"),
                    fn.sum("m", edge_type), # pylint: disable=E1101
                    etype=edge_type,
                )
                node_type_embed.append(block.dstdata[edge_type])

            # Stack into [batch_size, edge_type_count, embedding_u_size], then
            # flatten so each (node, edge type) pair is one row for batched matmul.
            node_type_embed = torch.stack(node_type_embed, 1)
            tmp_node_type_embed = node_type_embed.unsqueeze(2).view(
                -1, 1, self.embedding_u_size
            )
            # Tile the shared transformation matrices across the batch so the
            # flattened (node, edge type) rows can be multiplied in one batched matmul.
            trans_w = (
                self.trans_weights.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.embedding_u_size, self.embedding_size)
            )
            trans_w_s1 = (
                self.trans_weights_s1.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.embedding_u_size, self.dim_a)
            )
            trans_w_s2 = (
                self.trans_weights_s2.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.dim_a, 1)
            )

            # Two-layer attention over edge types: project each type embedding to
            # dim_a with tanh, score it, and softmax across the edge_type_count
            # types belonging to the same node.
            attention = (
                F.softmax(
                    torch.matmul(
                        torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)),
                        trans_w_s2,
                    )
                    .squeeze(2)
                    .view(-1, self.edge_type_count),
                    dim=1,
                )
                .unsqueeze(1)
                .repeat(1, self.edge_type_count, 1)
            )

            # Attention-weighted mixture of the type embeddings, projected into the
            # base embedding space and added to the shared node embedding.
            node_type_embed = torch.matmul(attention, node_type_embed).view(
                -1, 1, self.embedding_u_size
            )
            node_embed = node_embed[output_nodes].unsqueeze(1).repeat(
                1, self.edge_type_count, 1
            ) + torch.matmul(node_type_embed, trans_w).view(
                -1, self.edge_type_count, self.embedding_size
            )
            # L2-normalize each per-type embedding.
            last_node_embed = F.normalize(node_embed, dim=2)

            return last_node_embed  # [batch_size, edge_type_count, embedding_size]
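
A minimal sketch of driving this forward pass; the toy graph, seed nodes, and the
already-constructed `model` instance are illustrative assumptions, not taken from
the repository:

    import dgl
    import torch

    # Hypothetical multiplex graph: one node type, one relation per edge type
    # the model was built with (here, edge_types == ["e1", "e2"]).
    g = dgl.heterograph({
        ("node", "e1", "node"): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])),
        ("node", "e2", "node"): (torch.tensor([0, 2]), torch.tensor([2, 1])),
    })

    seeds = torch.tensor([0, 1])        # mini-batch of target nodes
    sub = dgl.in_subgraph(g, seeds)     # keep only edges pointing at the seeds
    block = dgl.to_block(sub, seeds)    # bipartite block with dgl.NID metadata

    out = model(block)                  # [len(seeds), edge_type_count, embedding_size]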