def forward()

in hugegraph-ml/src/hugegraph_ml/models/seal.py [0:0]


    def forward(self, g, z, node_id=None, edge_id=None):
        """Run the graph through the message-passing layers and the readout head.

        Args:
            g(DGLGraph): the graph
            z(Tensor): node labeling tensor, shape [N, 1]
            node_id(Tensor, optional): node id tensor, shape [N, 1]. Only read
                when ``self.use_attribute`` or ``self.use_embedding`` is set.
            edge_id(Tensor, optional): edge id tensor, shape [E, 1]. Only read
                when ``self.use_edge_weight`` is set.
        Returns:
            x(Tensor): output tensor produced by ``self.linear_2``
        """
        # Build the initial node features: the labeling embedding, optionally
        # concatenated with looked-up node attributes and a node embedding.
        x = self.z_embedding(z)
        if self.use_attribute:
            x = torch.cat([x, self.node_attributes_lookup(node_id)], 1)

        edge_weight = (
            self.edge_weights_lookup(edge_id) if self.use_edge_weight else None
        )

        if self.use_embedding:
            x = torch.cat([x, self.node_embedding(node_id)], 1)

        # Each layer consumes the previous layer's output; every layer output
        # is kept so they can all be concatenated afterwards.
        xs = [x]
        for layer in self.layers:
            out = torch.tanh(layer(g, xs[-1], edge_weight=edge_weight))
            xs.append(out)

        # Skip xs[0] (the raw input features): only layer outputs are used.
        x = torch.cat(xs[1:], dim=-1)

        # SortPooling
        x = self.pooling(g, x)
        x = x.unsqueeze(1)  # add a channel dimension for the 1-D convolutions
        x = F.relu(self.conv_1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv_2(x))
        x = x.view(x.size(0), -1)  # flatten everything except the first dim

        x = F.relu(self.linear_1(x))
        # F.dropout is not in-place by default: the result must be assigned
        # back, otherwise dropout is silently never applied during training.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.linear_2(x)

        return x