def _forward()

in lmgvp/modules.py [0:0]


    def _forward(self, batch):
        """Run one forward pass of the model over a batch of graphs.

        The pipeline is:

        1. Reshape the flattened token ids and attention masks to one row per
           graph.
        2. Embed them with ``self.bert_model`` via ``_bert_forward``.
        3. Refine the node vectors with three stacked graph convolutions.
        4. Concatenate every layer's output and apply ReLU and dropout.
        5. Mean-pool nodes into one vector per graph.
        6. Project the result with ``self.dense``.

        Args:
            batch: torch_geometric.data.Data batch. It must expose
                ``edge_index``, ``num_graphs``, ``input_ids``,
                ``attention_mask`` and ``batch``, the node-to-graph index used
                for pooling.

        Returns:
            Inferenced logits: the ``self.dense`` output with its last
            dimension squeezed, shifted by +0.5.
        """
        n_graphs = batch.num_graphs
        # Token tensors arrive flattened; restore one row per graph.
        token_ids = batch.input_ids.reshape(n_graphs, -1)
        token_mask = batch.attention_mask.reshape(n_graphs, -1)
        graph_edges = batch.edge_index

        # Node embeddings produced by the language model.
        hidden = _bert_forward(
            self.bert_model, self.embeding_dim, token_ids, token_mask
        )

        # GAT stack: each conv consumes the previous conv's output, and every
        # intermediate result is kept for the concatenation below.
        layer_outputs = []
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = conv(hidden, graph_edges)
            layer_outputs.append(hidden)

        # "Residual concat": join all layer outputs along the feature axis.
        # NOTE(review): the original notes this as [n_nodes, 2048]; the width
        # depends on the conv layer sizes defined elsewhere — confirm.
        node_repr = self.dropout(self.relu(torch.cat(layer_outputs, dim=-1)))

        # Mean of the node vectors belonging to each graph.
        graph_repr = scatter_mean(node_repr, batch.batch, dim=0)

        # NOTE(review): the +0.5 offset is a hard-coded constant with no
        # visible rationale here; presumably it matches the target's
        # range/centering — verify against the training setup.
        return self.dense(graph_repr).squeeze(-1) + 0.5