in models.py
# Module-level imports this method depends on:
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_sort_pool


def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None):
    # Embed the structural node labels z; if each node carries multiple
    # integer labels, the embedding is [num_nodes, num_labels, dim] and
    # is summed over the label axis.
    z_emb = self.z_embedding(z)
    if z_emb.ndim == 3:
        z_emb = z_emb.sum(dim=1)
    # Concatenate the label embedding with raw node features and/or a
    # learned per-node embedding, when available.
    if self.use_feature and x is not None:
        x = torch.cat([z_emb, x.to(torch.float)], dim=1)
    else:
        x = z_emb
    if self.node_embedding is not None and node_id is not None:
        n_emb = self.node_embedding(node_id)
        x = torch.cat([x, n_emb], dim=1)
    # Run the GNN stack and concatenate the outputs of all layers.
    xs = [x]
    for conv in self.convs:
        xs += [torch.tanh(conv(xs[-1], edge_index, edge_weight))]
    x = torch.cat(xs[1:], dim=-1)
    # Global pooling: keep the top-k nodes per graph, sorted by their
    # last feature channel.
    x = global_sort_pool(x, batch, self.k)  # [num_graphs, k * hidden]
    x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
    # 1-D convolutions over the sorted node sequence.
    x = F.relu(self.conv1(x))
    x = self.maxpool1d(x)
    x = F.relu(self.conv2(x))
    x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]
    # MLP head.
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.lin2(x)
    return x
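

# --- Usage sketch (not part of models.py) ---
# The shape-critical step above is global_sort_pool: it sorts each graph's
# nodes by their last feature channel, keeps the top k per graph (zero-padding
# graphs with fewer than k nodes), and flattens the result to
# [num_graphs, k * channels]. A minimal, self-contained demonstration with
# toy tensors (the sizes here are hypothetical, not the model's real ones):
import torch
from torch_geometric.nn import global_sort_pool

feats = torch.randn(7, 4)                    # 7 nodes, 4 channels each
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # node-to-graph assignment
pooled = global_sort_pool(feats, batch, k=3)
print(pooled.shape)                          # torch.Size([2, 12]) == [num_graphs, k * channels]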