in lmgvp/modules.py [0:0]
def _forward(self, batch):
    """Run the model on one mini-batch of graphs and return per-graph outputs.

    Args:
        batch: torch_geometric.data.Data mini-batch. Must expose
            ``edge_index``, ``num_graphs``, ``input_ids``, ``attention_mask``
            and ``batch`` (the node-to-graph index used for pooling).

    Returns:
        ``self.dense`` applied to the mean-pooled graph features, with the
        trailing dimension squeezed and a constant 0.5 added.
    """
    edges = batch.edge_index
    n_graphs = batch.num_graphs
    # Token tensors arrive flattened on the batch; restore one row per graph.
    token_ids = batch.input_ids.reshape(n_graphs, -1)
    token_mask = batch.attention_mask.reshape(n_graphs, -1)
    # Language-model encoding; the result is consumed as node features below.
    hidden = _bert_forward(
        self.bert_model, self.embeding_dim, token_ids, token_mask
    )

    # Apply the three graph conv layers in sequence, keeping every layer's
    # output so that all of them can be concatenated feature-wise.
    layer_outputs = []
    for conv in (self.conv1, self.conv2, self.conv3):
        hidden = conv(hidden, edges)
        layer_outputs.append(hidden)
    node_features = self.dropout(self.relu(torch.cat(layer_outputs, dim=-1)))

    # Mean-pool node features into one vector per graph via batch.batch.
    # NOTE(review): the original comments claim a width of 2048 — confirm
    # against the conv layer output sizes.
    graph_features = scatter_mean(node_features, batch.batch, dim=0)

    # NOTE(review): the purpose of the +0.5 offset is not visible here
    # (possibly centering a regression target) — confirm with training code.
    return self.dense(graph_features).squeeze(-1) + 0.5