def convert_to_graph()

in src/dataset.py [0:0]


def convert_to_graph(protein: dict, k: int = 3) -> dgl.DGLGraph:
    """
    Build a k-nearest-neighbour DGL graph for a single protein record.

    Nodes are connected according to ``dgl.knn_graph`` run on the C-alpha
    coordinates, and each node carries a one-hot encoding of its residue
    under the ``"h"`` node-data key.

    Args:
        protein: Mapping that provides ``"coords"`` (converted to a tensor
            and sliced with ``[:, 1]``) and ``"seq"`` (an iterable of residue
            symbols looked up in ``d1_to_index``).
            NOTE(review): presumably ``coords`` is shaped
            (num_residues, num_atoms, 3) with C-alpha at atom index 1 and
            ``seq`` has one symbol per residue -- confirm against the loader.
        k: Number of nearest neighbours passed to ``dgl.knn_graph``.

    Returns:
        The kNN graph with ``ndata["h"]`` set to a float one-hot matrix whose
        width is ``len(d1_to_index)``.
    """
    # Index 1 on the second axis selects the C-alpha atom of every residue.
    ca_positions = torch.tensor(protein["coords"])[:, 1]
    graph = dgl.knn_graph(ca_positions, k=k)

    # Map each residue symbol to its integer class; symbols missing from
    # d1_to_index are not handled here and propagate as a lookup error.
    residue_ids = torch.tensor([d1_to_index[aa] for aa in protein["seq"]])
    one_hot = F.one_hot(residue_ids, num_classes=len(d1_to_index))

    # F.one_hot yields integer tensors; store the features as floats.
    graph.ndata["h"] = one_hot.to(dtype=torch.float)
    return graph