def load_GO_labels()

in lmgvp/data_loaders.py [0:0]


def load_GO_labels(task="cc"):
    """Load the labels and per-class positive weights of the GO dataset.

    The annotations are read from the DeepFRI annotation file located under
    ``DATA_ROOT_DIR`` via ``load_GO_annot``.

    Args:
        task: String. GO task. One of: 'cc', 'bp', 'mf'

    Return:
        Tuple ``(prot2annot, num_outputs, pos_weights)`` where:
          - ``prot2annot`` is the first value returned by ``load_GO_annot``,
            passed through unchanged (a dictionary mapping proteins to their
            targets; it is NOT restricted to ``task``).
          - ``num_outputs`` is an integer with the number of GO terms
            (outputs) of the selected task.
          - ``pos_weights`` is a ``torch.float32`` tensor with one weight per
            class count of the task, computed as
            ``mean_class_size / class_size`` and clamped to ``[1.0, 10.0]``.

    Raises:
        KeyError: if ``task`` is not one of the tasks present in the
            annotation data.
    """
    annot_path = os.path.join(
        DATA_ROOT_DIR,
        "DeepFRI_GO_PDB/data/nrPDB-GO_2019.06.18_annot.tsv",
    )
    # The third returned value (GO term names) is not needed here.
    prot2annot, goterms, _gonames, counts = load_GO_annot(annot_path)

    # Fail with an informative message instead of a bare KeyError.
    # NOTE(review): assumes `goterms` is a dict keyed by task name — confirm
    # against load_GO_annot.
    if task not in goterms:
        raise KeyError(
            f"Unknown GO task {task!r}; expected one of {sorted(goterms)}"
        )

    num_outputs = len(goterms[task])

    # Computing weights for imbalanced GO classes: classes smaller than the
    # mean get a weight above 1, larger ones would fall below 1.
    class_sizes = counts[task]
    pos_weights = np.mean(class_sizes) / class_sizes
    # Clamp so frequent classes are never down-weighted (min 1.0) and rare
    # classes never dominate the loss (max 10.0).
    pos_weights = np.clip(pos_weights, 1.0, 10.0)

    # To tensor (float32).
    pos_weights = torch.from_numpy(pos_weights.astype(np.float32))
    return prot2annot, num_outputs, pos_weights