in lmgvp/datasets.py [0:0]
def _featurize_as_graph(self, protein):
    """Featurize one protein record as a graph for the GNN.

    Args:
        protein: Dictionary providing the keys ``"name"``, ``"input_ids"``,
            ``"attention_mask"`` and ``"coords"``. ``"coords"`` must be
            convertible to a float32 tensor with three dimensions (it is
            summed over dims 1 and 2); presumably laid out as
            (residues, backbone atoms, xyz) -- TODO confirm against the
            code that builds these records. The input is not modified.

    Returns:
        A ``torch_geometric.data.Data`` instance carrying ``x``,
        ``input_ids``, ``attention_mask``, ``name``, ``node_s``,
        ``node_v``, ``edge_s``, ``edge_v``, ``edge_index`` and ``mask``.
    """
    name = protein["name"]
    input_ids = protein["input_ids"]
    attention_mask = protein["attention_mask"]

    with torch.no_grad():
        # torch.as_tensor may alias the caller's buffer (e.g. a float32
        # NumPy array, or a tensor already on self.device). Clone so the
        # in-place write below never leaks back into ``protein``.
        coords = torch.as_tensor(
            protein["coords"], device=self.device, dtype=torch.float32
        ).clone()

        # True for every entry along dim 0 whose coordinates are all
        # finite; the remaining entries are overwritten entirely with inf.
        mask = torch.isfinite(coords.sum(dim=(1, 2)))
        coords[~mask] = np.inf

        # Index 1 on the second axis; presumably the C-alpha atom --
        # TODO confirm the atom ordering used upstream.
        X_ca = coords[:, 1]
        edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k)

        pos_embeddings = self._positional_embeddings(edge_index)
        E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
        rbf = _rbf(
            E_vectors.norm(dim=-1),
            D_count=self.num_rbf,
            device=self.device,
        )

        dihedrals = self._dihedrals(coords)
        orientations = self._orientations(X_ca)
        sidechains = self._sidechains(coords)

        node_s = dihedrals
        node_v = torch.cat(
            [orientations, sidechains.unsqueeze(-2)], dim=-2
        )
        edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
        edge_v = _normalize(E_vectors).unsqueeze(-2)

        # Entries set to inf above can propagate NaN/inf into the
        # features; replace those with finite values.
        node_s, node_v, edge_s, edge_v = map(
            torch.nan_to_num, (node_s, node_v, edge_s, edge_v)
        )

    # NOTE: ``x`` is passed through as-is (not sanitised by nan_to_num), so
    # rows where ``mask`` is False still hold inf; consumers should use
    # ``mask`` to ignore them.
    data = torch_geometric.data.Data(
        x=X_ca,
        input_ids=input_ids,
        attention_mask=attention_mask,
        name=name,
        node_s=node_s,
        node_v=node_v,
        edge_s=edge_s,
        edge_v=edge_v,
        edge_index=edge_index,
        mask=mask,
    )
    return data