def load_dataset()

in src/baselines/dnn.py [0:0]


def _encode_label(label, task_type, label_map):
    '''
    Convert one raw label field into the numeric target for ``task_type``.
    :param label: raw label value read from a data row: a label name, a numeric
        string (regression), or, for multi-label tasks, either a string holding a
        Python list literal of label names or an already-parsed iterable of names.
    :param task_type: one of "multi-class"/"multi_class", "binary-class"/"binary_class",
        "multi-label"/"multi_label", or "regression".
    :param label_map: dict mapping each label name to its integer index.
    :return: an int class index (multi-class / binary-class), a float (regression),
        or a multi-hot list of length ``len(label_map)`` (multi-label).
    :raises KeyError: if a label name is not present in ``label_map``.
    :raises ValueError: if ``task_type`` is not one of the supported values, or a
        multi-label string is not a valid Python literal.
    '''
    import ast

    if task_type in ("multi-class", "multi_class", "binary-class", "binary_class"):
        return label_map[label]
    if task_type == "regression":
        return float(label)
    if task_type in ("multi-label", "multi_label"):
        # Parse with literal_eval instead of eval: the string comes from a data
        # file, and eval would execute arbitrary expressions found there.
        label_names = ast.literal_eval(label) if isinstance(label, str) else label
        # Build the multi-hot vector in a separate list so the label names being
        # iterated are never overwritten.
        multi_hot = [0] * len(label_map)
        for label_name in label_names:
            multi_hot[label_map[label_name]] = 1
        return multi_hot
    raise ValueError("unsupported task_type: %r" % (task_type,))


def load_dataset(args, dataset_type):
    '''
    Load one data split as a TensorDataset of embeddings and labels.
    If "<data_dir>/<dataset_type>_emb.npz" exists it is loaded as a cache.
    Otherwise the split's row file is read. For every row whose embedding file
    exists under "<data_dir>/embs/", the embedding is taken from that file.
    The label is encoded according to ``args.task_type``, and the resulting
    arrays are written to the .npz cache before returning.
    NOTE(review): the cache filename does not encode task_type, filename_pattern
    or the representation index, so a stale cache is reused silently if those change.
    :param args: namespace providing ``label_filepath`` (may be None/empty),
        ``data_dir``, ``filename_pattern`` (optional format string taking the split
        name), and ``task_type``.
    :param dataset_type: split name substituted into file names (e.g. the
        "%s_emb.npz" cache and "%s_with_pdb_emb.csv" input).
    :return: (TensorDataset(x float32, y), label_list). y is float32 for
        regression and int64 (torch.long) for all other task types.
    :raises ValueError: if rows must be parsed and ``args.task_type`` is unsupported.
    '''
    # Prefer an explicitly configured label file; fall back to <data_dir>/label.txt.
    # The truthiness check avoids os.path.exists(None), which raises TypeError.
    if args.label_filepath and os.path.exists(args.label_filepath):
        label_list = load_labels(args.label_filepath, header=True)
    else:
        label_list = load_labels(os.path.join(args.data_dir, "label.txt"), header=True)
    label_map = {name: idx for idx, name in enumerate(label_list)}

    npz_filepath = os.path.join(args.data_dir, "%s_emb.npz" % dataset_type)
    if os.path.exists(npz_filepath):
        # Use the NpzFile as a context manager so its file handle is closed.
        # Indexing reads each array into memory, so x and y stay valid afterwards.
        with np.load(npz_filepath, allow_pickle=True) as npzfile:
            x = npzfile["x"]
            y = npzfile["y"]
    else:
        if args.filename_pattern:
            filepath = os.path.join(args.data_dir, args.filename_pattern.format(dataset_type))
        else:
            filepath = os.path.join(args.data_dir, "%s_with_pdb_emb.csv" % dataset_type)
        # Header handling is enabled only for .csv inputs.
        is_csv = filepath.endswith(".csv")
        x = []
        y = []
        for row in file_reader(filepath, header=is_csv, header_filter=is_csv):
            # Row columns: prot_id, seq, seq_len, pdb_filename, ptm, mean_plddt,
            # emb_filename, label, source. Only emb_filename and label are used.
            # Unpacking all nine still enforces the expected column count.
            _, _, _, _, _, _, emb_filename, label, _ = row
            embedding_filepath = os.path.join(args.data_dir, "embs", emb_filename)
            if not os.path.exists(embedding_filepath):
                # Rows without an embedding file are skipped.
                continue
            # NOTE(review): torch.load unpickles the file; only load trusted embedding files.
            emb = torch.load(embedding_filepath)
            # NOTE(review): index 36 into "bos_representations" is hard-coded and
            # presumably selects one specific model layer. Confirm it matches the
            # model that produced the embeddings.
            x.append(emb["bos_representations"][36].numpy())
            y.append(_encode_label(label, args.task_type, label_map))
            if len(y) % 10000 == 0:
                print("done %d" % len(y))
        x = np.array(x)
        y = np.array(y)
        np.savez(npz_filepath, x=x, y=y)
        print("%s: x.shape: %s, y.shape: %s" % (dataset_type, str(x.shape), str(y.shape)))

    # Regression targets are real-valued; casting them to long would truncate them.
    y_dtype = torch.float32 if args.task_type == "regression" else torch.long
    dataset = torch.utils.data.TensorDataset(
        torch.tensor(x, dtype=torch.float32),
        torch.tensor(y, dtype=y_dtype),
    )
    return dataset, label_list