in hugegraph-ml/src/hugegraph_ml/models/seal.py [0:0]
def forward(self, g, z, node_id=None, edge_id=None):
    """Encode ``g`` with the GNN layer stack, pool it, and score it with an MLP head.

    Node features start from the embedding of the labeling tensor ``z``.
    Node attributes and/or a learned node embedding are concatenated when
    enabled. Optional per-edge weights are passed to every GNN layer.

    Args:
        g(DGLGraph): the graph
        z(Tensor): node labeling tensor, shape [N, 1]
        node_id(Tensor, optional): node id tensor, shape [N, 1]; used (and
            therefore required) when ``self.use_attribute`` or
            ``self.use_embedding`` is enabled
        edge_id(Tensor, optional): edge id tensor, shape [E, 1]; used (and
            therefore required) when ``self.use_edge_weight`` is enabled
    Returns:
        x(Tensor): output tensor
    """
    z_emb = self.z_embedding(z)
    if self.use_attribute:
        x = self.node_attributes_lookup(node_id)
        x = torch.cat([z_emb, x], 1)
    else:
        x = z_emb

    if self.use_edge_weight:
        edge_weight = self.edge_weights_lookup(edge_id)
    else:
        edge_weight = None

    if self.use_embedding:
        n_emb = self.node_embedding(node_id)
        x = torch.cat([x, n_emb], 1)

    # Hidden GNN layers use ReLU + dropout. The final GNN layer's output
    # goes straight to pooling without an activation.
    for layer in self.layers[:-1]:
        x = layer(g, x, edge_weight=edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.layers[-1](g, x, edge_weight=edge_weight)

    x = self.pooling(g, x)
    x = F.relu(self.linear_1(x))
    # F.dropout is not in-place: its result must be assigned back, otherwise
    # dropout on the MLP head is silently skipped during training.
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.linear_2(x)
    return x