in src/dataset.py [0:0]
def convert_to_graph(protein: dict, k: int = 3) -> dgl.DGLGraph:
    """
    Build a k-nearest-neighbour dgl graph from a protein record.

    Args:
        protein: Mapping with a "coords" entry (converted to a tensor; index 1
            along its second axis is used as the C-alpha position) and a
            "seq" entry whose items are keys of ``d1_to_index``.
        k: Number of nearest neighbours passed to ``dgl.knn_graph``.

    Returns:
        The kNN graph, with one-hot residue features (float, width
        ``len(d1_to_index)``) stored under ``ndata["h"]``.
    """
    ca_atom_index = 1  # slot of the C-alpha atom along axis 1
    # NOTE(review): presumably coords is (num_residues, num_atoms, 3) -- confirm
    ca_positions = torch.tensor(protein["coords"])[:, ca_atom_index]

    # kNN graph built over the C-alpha coordinates
    graph = dgl.knn_graph(ca_positions, k=k)

    # map every residue to its integer class, then expand to one-hot floats
    residue_ids = torch.tensor([d1_to_index[aa] for aa in protein["seq"]])
    one_hot = F.one_hot(residue_ids, num_classes=len(d1_to_index))

    # attach the encoded residues as node features
    graph.ndata["h"] = one_hot.float()
    return graph