in lmgvp/data_loaders.py [0:0]
def load_GO_labels(task="cc", *, min_weight=1.0, max_weight=10.0):
    """Load the labels of the GO dataset together with per-term class weights.

    Args:
        task: String. GO task. One of: 'cc', 'bp', 'mf'. Used as the key into
            the term lists and counts returned by ``load_GO_annot``.
        min_weight: Float. Lower bound applied to every class weight.
            Defaults to 1.0 (the previously hard-coded value).
        max_weight: Float. Upper bound applied to every class weight.
            Defaults to 10.0 (the previously hard-coded value).

    Returns:
        Tuple ``(prot2annot, num_outputs, pos_weights)``:
            prot2annot: the first value returned by ``load_GO_annot``
                (a dictionary mapping proteins to their targets).
            num_outputs: int, the number of GO terms for ``task``.
            pos_weights: ``torch.float32`` tensor with the same shape as the
                task's class counts (presumably one weight per GO term --
                TODO confirm ``counts[task]`` is 1-D and aligned with the
                term list).

    Raises:
        KeyError: presumably, when ``task`` is not one of the ontologies in
            the annotation file (raised by the ``[task]`` lookups below).
    """
    prot2annot, goterms, _gonames, counts = load_GO_annot(
        os.path.join(
            DATA_ROOT_DIR,
            "DeepFRI_GO_PDB/data/nrPDB-GO_2019.06.18_annot.tsv",
        )
    )
    num_outputs = len(goterms[task])

    # Inverse-frequency weights for imbalanced GO classes: a term smaller than
    # the mean class size gets a weight > 1, a larger one gets a weight < 1.
    class_sizes = np.asarray(counts[task], dtype=np.float64)
    mean_class_size = class_sizes.mean()
    # A term with zero occurrences yields inf here, which the clip below maps
    # to max_weight; silence the expected divide-by-zero warning.
    with np.errstate(divide="ignore"):
        pos_weights = mean_class_size / class_sizes
    pos_weights = np.clip(pos_weights, min_weight, max_weight)

    pos_weights = torch.from_numpy(pos_weights.astype(np.float32))
    return prot2annot, num_outputs, pos_weights