in src/deep_baselines/run.py [0:0]
def _encode_label(raw_label, task_type, label_map):
    '''
    Convert one raw label field into the numeric target for ``task_type``.
    :param raw_label: label as read from the dataset row; for multi-label tasks
        either a string holding a Python-literal list of label names
        (e.g. "['a', 'b']") or an already-parsed iterable of label names
    :param task_type: "multi-class"/"multi_class", "binary-class"/"binary_class",
        "regression", or "multi-label"/"multi_label"
    :param label_map: mapping from label name to class index
    :return: int class index, float regression target, or multi-hot list of 0/1
    :raises ValueError: if ``task_type`` is not one of the supported values
    '''
    import ast

    if task_type in ("multi-class", "multi_class", "binary-class", "binary_class"):
        return label_map[raw_label]
    if task_type == "regression":
        return float(raw_label)
    if task_type in ("multi-label", "multi_label"):
        if isinstance(raw_label, str):
            # literal_eval parses list literals without executing arbitrary
            # code found in the data file (the previous version used eval).
            label_names = ast.literal_eval(raw_label)
        else:
            label_names = raw_label
        # Build the multi-hot vector in a separate variable so the raw label
        # is still available while iterating over its names.
        multi_hot = [0] * len(label_map)
        for label_name in label_names:
            multi_hot[label_map[label_name]] = 1
        return multi_hot
    raise ValueError("unsupported task_type: %s" % task_type)


def _encode_split_file(args, dataset_type, encode_func, encode_func_args, label_map):
    '''
    Read the raw file for one dataset split and encode every row.
    :param args: run arguments (uses data_dir, filename_pattern, task_type)
    :param dataset_type: split name substituted into the file name
    :param encode_func: encode function, called with encode_func_args plus seq=<upper-cased sequence>
    :param encode_func_args: extra keyword args for encode_func (not modified)
    :param label_map: mapping from label name to class index
    :return: (x, y, lens) numpy arrays of encoded sequences, labels and lengths
    '''
    if args.filename_pattern:
        filepath = os.path.join(args.data_dir, args.filename_pattern.format(dataset_type))
    else:
        filepath = os.path.join(args.data_dir, "%s_with_pdb_emb.csv" % dataset_type)
    # Only .csv inputs are read with header handling enabled.
    is_csv = filepath.endswith(".csv")
    x, y, lens = [], [], []
    for cnt, row in enumerate(file_reader(filepath, header=is_csv, header_filter=is_csv), start=1):
        # Rows must have exactly these nine columns; only seq and label are used here.
        _prot_id, seq, _seq_len, _pdb_filename, _ptm, _mean_plddt, _emb_filename, label, _source = row
        # Work on a copy so the caller's dict is not left holding the last sequence.
        call_args = dict(encode_func_args)
        call_args["seq"] = seq.upper()
        seq_ids, actual_len = encode_func(**call_args)
        x.append(seq_ids)
        y.append(_encode_label(label, args.task_type, label_map))
        lens.append(actual_len)
        if cnt % 10000 == 0:
            print("done %d" % cnt)
    return np.array(x), np.array(y), np.array(lens)


def load_dataset(args, dataset_type, encode_func, encode_func_args):
    '''
    load dataset
    Encoded arrays are cached in <data_dir>/<dataset_type>_<model_type>_<one_hot_encode>.npz;
    if that file exists it is loaded instead of re-reading the raw split file.
    :param args: run arguments (uses label_filepath, data_dir, model_type,
        one_hot_encode, filename_pattern, task_type)
    :param dataset_type: split name, e.g. used to build the input/cache file names
    :param encode_func: encode function, returns (seq_ids, actual_len)
    :param encode_func_args: encode function args (the caller's dict is not modified)
    :return: (TensorDataset of (x, y, lens), label_list); y is float32 for
        regression and long for all other task types
    '''
    # label_filepath may be None/empty, in which case fall back to label.txt.
    if args.label_filepath and os.path.exists(args.label_filepath):
        label_list = load_labels(args.label_filepath, header=True)
    else:
        label_list = load_labels(os.path.join(args.data_dir, "label.txt"), header=True)
    label_map = {name: idx for idx, name in enumerate(label_list)}
    npz_filepath = os.path.join(args.data_dir, "%s_%s_%s.npz" % (dataset_type, args.model_type, str(args.one_hot_encode)))
    if os.path.exists(npz_filepath):
        # allow_pickle is needed when the cached arrays have object dtype; only
        # load cache files this function wrote. The context manager closes the file.
        with np.load(npz_filepath, allow_pickle=True) as npzfile:
            x = npzfile["x"]
            y = npzfile["y"]
            lens = npzfile["lens"]
    else:
        x, y, lens = _encode_split_file(args, dataset_type, encode_func, encode_func_args, label_map)
        np.savez(npz_filepath, x=x, y=y, lens=lens)
    print("%s: x.shape: %s, y.shape: %s, lens.shape: %s" %(dataset_type, str(x.shape), str(y.shape), str(lens.shape)))
    # Regression targets are floats; casting them to long would truncate them.
    label_dtype = torch.float32 if args.task_type == "regression" else torch.long
    dataset = torch.utils.data.TensorDataset(
        torch.tensor(x, dtype=torch.long),
        torch.tensor(y, dtype=label_dtype),
        torch.tensor(lens, dtype=torch.long),
    )
    return dataset, label_list