in lmgvp/modules.py [0:0]
def _forward(self, batch):
    """Run the GVP network forward pass on one batch.

    Args:
        batch: torch_geometric.data.Data (per the original docstring). The
            attributes read here are node_s, node_v, edge_s, edge_v,
            edge_index, seq and batch.
    Returns:
        logits: output of ``self.dense`` with the last dimension squeezed
            and a constant 0.5 added.
    """
    node_scalars, node_vectors = batch.node_s, batch.node_v
    edge_feats = (batch.edge_s, batch.edge_v)
    connectivity = batch.edge_index

    seq = batch.seq
    if seq is not None:
        # Optional sequence input (described upstream as one-hot encodings):
        # pass it through W_s and append it to the scalar node features.
        node_scalars = torch.cat([node_scalars, self.W_s(seq)], dim=-1)

    node_feats = self.W_v((node_scalars, node_vectors))
    edge_feats = self.W_e(edge_feats)

    # GVP Conv layers
    if self.residual:
        # Keep the output of every conv layer; each layer consumes the
        # output of the one before it.
        per_layer = []
        current = node_feats
        for conv in self.layers:
            current = conv(current, connectivity, edge_feats)
            per_layer.append(current)
        # Concatenate the collected outputs, separately for the scalar part
        # (index 0, along dim -1) and the vector part (index 1, along dim -2).
        scalar_cat = torch.cat([feats[0] for feats in per_layer], dim=-1)
        vector_cat = torch.cat([feats[1] for feats in per_layer], dim=-2)
        out = self.W_out((scalar_cat, vector_cat))
    else:
        # Plain stack: only the final layer's output is used.
        for conv in self.layers:
            node_feats = conv(node_feats, connectivity, edge_feats)
        out = self.W_out(node_feats)

    # Aggregate node-level outputs to graph level: mean over dim 0, grouped
    # by the batch.batch index.
    pooled = scatter_mean(out, batch.batch, dim=0)
    # NOTE(review): the constant 0.5 offset is kept as-is; its rationale is
    # not visible in this block.
    return self.dense(pooled).squeeze(-1) + 0.5