in models.py [0:0]
def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None):
    # Embed the structural node labels (z) computed for each enclosing subgraph.
    z_emb = self.z_embedding(z)
    if z_emb.ndim == 3:  # in case z has multiple integer labels per node
        z_emb = z_emb.sum(dim=1)
    # Optionally concatenate raw node features and/or learned node embeddings.
    if self.use_feature and x is not None:
        x = torch.cat([z_emb, x.to(torch.float)], 1)
    else:
        x = z_emb
    if self.node_embedding is not None and node_id is not None:
        n_emb = self.node_embedding(node_id)
        x = torch.cat([x, n_emb], 1)
    # Message passing: ReLU + dropout after every conv layer except the last.
    for conv in self.convs[:-1]:
        x = conv(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.convs[-1](x, edge_index, edge_weight)
    if True:  # center pooling
        # `batch` is sorted, so the first occurrence of each graph id is the
        # index of that subgraph's source node; the target node follows it.
        _, center_indices = np.unique(batch.cpu().numpy(), return_index=True)
        x_src = x[center_indices]
        x_dst = x[center_indices + 1]
        # Represent each candidate link by the element-wise product of its endpoints.
        x = (x_src * x_dst)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
    else:  # sum pooling (disabled alternative: readout over all subgraph nodes)
        x = global_add_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
    return x
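
# A minimal, self-contained sketch (not part of models.py) illustrating the
# center-pooling index math used above. It assumes the convention that the
# first two nodes of every enclosing subgraph are the source and target of the
# candidate link; the shapes and values below are made up for illustration.
import numpy as np
import torch

# `batch` assigns each node to its subgraph and is sorted by subgraph id.
# Here: subgraph 0 has 3 nodes, subgraph 1 has 2, subgraph 2 has 4.
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])

# Dummy node embeddings after the last conv layer: 9 nodes, 4 channels.
x = torch.randn(9, 4)

# First occurrence of each subgraph id = index of that subgraph's source node.
_, center_indices = np.unique(batch.cpu().numpy(), return_index=True)
print(center_indices)          # [0 3 5]

x_src = x[center_indices]      # source-node embeddings, shape [3, 4]
x_dst = x[center_indices + 1]  # target-node embeddings, shape [3, 4]
link_repr = x_src * x_dst      # element-wise product fed to lin1/lin2
print(link_repr.shape)         # torch.Size([3, 4])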