in lmgvp/modules.py [0:0]
def _forward(self, batch, input_ids=None):
    """Run the forward pass and produce logits.

    Args:
        batch: torch_geometric.data.Data
        input_ids: IDs of the embeddings to be used in the model.
    Returns:
        logits
    """
    edge_index = batch.edge_index
    n_graphs = batch.num_graphs
    if input_ids is None:
        input_ids = batch.input_ids.reshape(n_graphs, -1)
    attention_mask = batch.attention_mask.reshape(n_graphs, -1)
    # Language-model embeddings for the sequence, passed through the
    # identity layer exactly as in the original formulation.
    seq_emb = self.identity(
        _bert_forward(
            self.bert_model, self.embeding_dim, input_ids, attention_mask
        )
    )
    # Fuse the LM embeddings into the scalar node channel; the vector
    # channel is left untouched.
    fused_scalars = torch.cat([batch.node_s, seq_emb], dim=-1)
    h_V = self.W_v((fused_scalars, batch.node_v))
    h_E = self.W_e((batch.edge_s, batch.edge_v))
    if self.residual:
        # Keep every GVPConvLayer output, then concatenate them all:
        # scalar features on the last dim, vector features on dim -2.
        states = []
        current = h_V
        for conv in self.layers:
            current = conv(current, edge_index, h_E)
            states.append(current)
        out = self.W_out(
            (
                torch.cat([s for s, _ in states], dim=-1),
                torch.cat([v for _, v in states], dim=-2),
            )
        )
    else:
        # Plain sequential stack: each layer consumes the previous output.
        for conv in self.layers:
            h_V = conv(h_V, edge_index, h_E)
        out = self.W_out(h_V)
    # Mean-pool node features per graph, then project to logits
    # (the +0.5 offset is part of the model's output definition).
    out = scatter_mean(out, batch.batch, dim=0)
    return self.dense(out).squeeze(-1) + 0.5