# load_dataset()
# from src/deep_baselines/run.py

def load_dataset(args, dataset_type, encode_func, encode_func_args):
    '''
    Load one dataset split, encode every sequence, and wrap the result as a
    TensorDataset. Encoded arrays are cached to an .npz file in args.data_dir
    so subsequent runs skip re-encoding.

    :param args: namespace providing label_filepath, data_dir, model_type,
        one_hot_encode, filename_pattern and task_type
    :param dataset_type: split name used to locate the input file (e.g. "train")
    :param encode_func: encode function; called as encode_func(**encode_func_args)
        and expected to return (seq_ids, actual_len)
    :param encode_func_args: keyword-arg dict for encode_func; its "seq" entry
        is overwritten for each sample
    :return: (TensorDataset(x, y, lens), label_list)
    '''
    import ast  # stdlib; used for safe parsing of multi-label strings below

    # Label vocabulary: prefer the explicit filepath, fall back to data_dir/label.txt.
    if os.path.exists(args.label_filepath):
        label_list = load_labels(args.label_filepath, header=True)
    else:
        label_list = load_labels(os.path.join(args.data_dir, "label.txt"), header=True)
    label_map = {name: idx for idx, name in enumerate(label_list)}

    npz_filepath = os.path.join(
        args.data_dir,
        "%s_%s_%s.npz" % (dataset_type, args.model_type, str(args.one_hot_encode))
    )
    if os.path.exists(npz_filepath):
        # Cache hit: reuse previously encoded arrays.
        npzfile = np.load(npz_filepath, allow_pickle=True)
        x = npzfile["x"]
        y = npzfile["y"]
        lens = npzfile["lens"]
    else:
        x, y, lens = [], [], []
        if args.filename_pattern:
            filepath = os.path.join(args.data_dir, args.filename_pattern.format(dataset_type))
        else:
            filepath = os.path.join(args.data_dir, "%s_with_pdb_emb.csv" % dataset_type)
        # CSV inputs carry a header row that must be skipped.
        header = filepath.endswith(".csv")
        header_filter = header
        cnt = 0
        for row in file_reader(filepath, header=header, header_filter=header_filter):
            prot_id, seq, seq_len, pdb_filename, ptm, mean_plddt, emb_filename, label, source = row
            encode_func_args["seq"] = seq.upper()
            seq_ids, actual_len = encode_func(**encode_func_args)
            if args.task_type in ["multi-class", "multi_class"]:
                label = label_map[label]
            elif args.task_type == "regression":
                label = float(label)
            elif args.task_type in ["multi-label", "multi_label"]:
                # Parse the label names FIRST, then build the multi-hot vector.
                # (The original overwrote `label` with the zero vector before
                # eval-ing it, so the string branch always raised TypeError.)
                if isinstance(label, str):
                    # literal_eval instead of eval: the label column is data
                    # from a file, never trusted code.
                    label_names = ast.literal_eval(label)
                else:
                    label_names = label
                multi_hot = [0] * len(label_map)
                for label_name in label_names:
                    multi_hot[label_map[label_name]] = 1
                label = multi_hot
            elif args.task_type in ["binary-class", "binary_class"]:
                label = label_map[label]
            x.append(seq_ids)
            y.append(label)
            lens.append(actual_len)
            cnt += 1
            if cnt % 10000 == 0:
                print("done %d" % cnt)
        x = np.array(x)
        y = np.array(y)
        lens = np.array(lens)
        np.savez(npz_filepath, x=x, y=y, lens=lens)
    print("%s: x.shape: %s, y.shape: %s, lens.shape: %s" %(dataset_type, str(x.shape), str(y.shape), str(lens.shape)))

    # Regression targets must stay floating point; casting them to long would
    # silently truncate the float(label) values built above. All other task
    # types produce integer class ids / multi-hot vectors.
    y_dtype = torch.float if args.task_type == "regression" else torch.long
    return torch.utils.data.TensorDataset(
        torch.tensor(x, dtype=torch.long),
        torch.tensor(y, dtype=y_dtype),
        torch.tensor(lens, dtype=torch.long),
    ), label_list