in hugegraph-ml/src/hugegraph_ml/models/bgnn.py [0:0]
def forward(self, graph, features):
    """Run the configured GNN backbone over ``graph`` and return its output.

    If ``self.use_mlp`` is set, ``features`` are first passed through
    ``self.mlp``. The MLP output either replaces the features or, when
    ``self.join_with_mlp`` is set, is concatenated to them along dim 1.
    The result is then fed through the layers selected by ``self.name``:
    ``"gat"``, ``"appnp"``, ``"agnn"``, ``"cheb"`` (or the legacy spelling
    ``"che3b"``), or ``"gcn"``.

    Args:
        graph: The graph passed to every convolution layer (presumably a
            ``dgl.DGLGraph`` -- confirm against the caller).
        features: Node feature tensor (presumably ``(num_nodes, in_dim)`` --
            TODO confirm against the caller).

    Returns:
        The output of the last layer of the selected backbone.

    Raises:
        ValueError: If ``self.name`` is not one of the supported backbones.
    """
    h = features
    if self.use_mlp:
        mlp_out = self.mlp(features)
        # Either augment the raw features with the MLP output or replace them.
        h = torch.cat((h, mlp_out), 1) if self.join_with_mlp else mlp_out

    name = self.name
    if name == "gat":
        # l1's output is flattened from dim 1 onward and l2's output is
        # averaged over dim 1 (presumably merging/averaging attention heads --
        # confirm against the layer construction).
        h = self.l1(graph, h).flatten(1)
        return self.l2(graph, h).mean(1)

    if name == "appnp":
        h = self.lin1(h)
        return self.l1(graph, h)

    if name == "agnn":
        h = self.lin1(h)
        h = self.l1(graph, h)
        h = self.l2(graph, h)
        return self.lin2(h)

    if name in ("cheb", "che3b"):
        # Accept "cheb" as well as the legacy "che3b" spelling. Previously
        # "cheb" matched no branch and forward silently returned None.
        lambda_max = dgl.laplacian_lambda_max(graph)
        h = self.drop(h)
        h = self.l1(graph, h, lambda_max)
        return self.l2(graph, h, lambda_max)

    if name == "gcn":
        h = self.drop(h)
        h = self.l1(graph, h)
        return self.l2(graph, h)

    # Fail loudly instead of returning None, which would only surface later
    # as a confusing error in the loss computation.
    raise ValueError(f"Unsupported GNN backbone name: {name!r}")