in hugegraph-ml/src/hugegraph_ml/models/diffpool.py [0:0]
def forward(self, g, feat):
    """Run the hierarchical pooling network and return its prediction.

    The method resets ``self.link_pred_loss`` and ``self.entropy_loss`` to
    empty lists, embeds the nodes with the pre-pooling GCN blocks, applies
    ``self.first_diffpool_layer`` and then every layer in
    ``self.diffpool_layers``, collecting a readout after each stage. The
    collected readouts are concatenated along dim 1 when ``self.concat`` is
    truthy or ``self.num_aggs > 1``; otherwise only the most recent readout
    is used. The chosen readout is fed to ``self.pred_layer``.

    Side effect: the node embedding is written to ``g.ndata["h"]``.

    Args:
        g: Input graph. It is passed to ``dgl.sum_nodes`` and must provide
            ``batch_num_nodes()`` (presumably a DGL batched graph --
            confirm against callers).
        feat: Node features; the same tensor is used for the embedding
            stage and, through it, for the first pooling layer.

    Returns:
        Whatever ``self.pred_layer`` produces for the final readout.
    """
    self.link_pred_loss = []
    self.entropy_loss = []

    readouts = []

    def _collect_dense(hidden):
        # Aggregate over dim 1: always a sum, plus a max when exactly two
        # aggregators are configured.
        readouts.append(torch.sum(hidden, dim=1))
        if self.num_aggs == 2:
            max_values, _ = torch.max(hidden, dim=1)
            readouts.append(max_values)

    # Stage 0: GCN blocks on the original graph, read out through DGL.
    node_embedding = _gcn_forward(g, feat, self.gc_before_pool, self.concat)
    g.ndata["h"] = node_embedding
    readouts.append(dgl.sum_nodes(g, "h"))
    if self.num_aggs == 2:
        readouts.append(dgl.max_nodes(g, "h"))

    # Stage 1: first pooling layer, then switch to the tensorized form.
    adj, pooled = self.first_diffpool_layer(g, node_embedding)
    nodes_per_pooled_graph = int(adj.size()[0] / len(g.batch_num_nodes()))
    pooled, adj = _batch2tensor(adj, pooled, nodes_per_pooled_graph)
    pooled = _gcn_forward_tensorized(
        pooled, adj, self.gc_after_pool[0], self.concat
    )
    _collect_dense(pooled)

    # Remaining stages: gc_after_pool[k] follows diffpool_layers[k - 1].
    for stage, pool_layer in enumerate(self.diffpool_layers, start=1):
        pooled, adj = pool_layer(pooled, adj)
        pooled = _gcn_forward_tensorized(
            pooled, adj, self.gc_after_pool[stage], self.concat
        )
        _collect_dense(pooled)

    if self.concat or self.num_aggs > 1:
        final_readout = torch.cat(readouts, dim=1)
    else:
        # The last collected entry is the most recent readout.
        final_readout = readouts[-1]
    return self.pred_layer(final_readout)