def _forward()

in lmgvp/modules.py [0:0]


    def _forward(self, batch, input_ids=None):
        """
        Run the language-model + GVP forward pass over a batch of graphs.

        Per-node embeddings from ``self.bert_model`` (via ``_bert_forward``)
        are appended to the scalar node features, both node and edge features
        are projected through ``W_v`` / ``W_e``, and the result is passed
        through every layer in ``self.layers``. When ``self.residual`` is set,
        the outputs of all layers are concatenated (scalars on the last dim,
        vectors on dim -2) before ``W_out``. Otherwise only the final layer
        output is used. Node outputs are then averaged per graph with
        ``scatter_mean`` and fed to ``self.dense``.

        Args:
            batch: torch_geometric.data.Data holding node_s/node_v,
                edge_s/edge_v, edge_index, attention_mask, batch, num_graphs
                (and input_ids when the argument below is None).
            input_ids: optional token IDs to use instead of
                ``batch.input_ids``; used as given (not reshaped).
        Returns:
            ``self.dense`` applied to the pooled features, with the last
            dimension squeezed and a constant 0.5 added.
        """
        n_graphs = batch.num_graphs
        edge_index = batch.edge_index
        node_scalars, node_vectors = batch.node_s, batch.node_v
        edge_feats = (batch.edge_s, batch.edge_v)

        # Token IDs / mask from the batch are flattened into one row per graph.
        if input_ids is None:
            input_ids = batch.input_ids.reshape(n_graphs, -1)
        mask = batch.attention_mask.reshape(n_graphs, -1)

        # self.identity is a pass-through module applied to the LM embeddings
        # (presumably a hook point, e.g. for attribution — TODO confirm).
        lm_embeddings = self.identity(
            _bert_forward(self.bert_model, self.embeding_dim, input_ids, mask)
        )

        # Only the scalar channel is enriched; vector features pass unchanged.
        node_feats = self.W_v(
            (torch.cat([node_scalars, lm_embeddings], dim=-1), node_vectors)
        )
        edge_feats = self.W_e(edge_feats)

        if self.residual:
            # Keep every layer's output so they can be concatenated for W_out.
            layer_outputs = []
            current = node_feats
            for layer in self.layers:
                current = layer(current, edge_index, edge_feats)
                layer_outputs.append(current)
            stacked = (
                torch.cat([o[0] for o in layer_outputs], dim=-1),
                torch.cat([o[1] for o in layer_outputs], dim=-2),
            )
            node_out = self.W_out(stacked)
        else:
            for layer in self.layers:
                node_feats = layer(node_feats, edge_index, edge_feats)
            node_out = self.W_out(node_feats)

        # Average node outputs within each graph (batch.batch maps node -> graph).
        pooled = scatter_mean(node_out, batch.batch, dim=0)
        # NOTE(review): the constant 0.5 shift's purpose is not visible here —
        # confirm against the training target / loss.
        return self.dense(pooled).squeeze(-1) + 0.5