in lmgvp/data_loaders.py [0:0]
import torch
from transformers import BertTokenizer

# NOTE: the helpers used below -- load_GO_labels, add_GO_labels, load_gvp_data,
# preprocess_seqs, and the SequenceDatasetWithTarget / ProteinGraphDatasetWithTarget /
# BertProteinGraphDatasetWithTarget classes -- are defined elsewhere in the lmgvp
# package; their exact import paths are not shown in this excerpt.
def get_dataset(task="", model_type="", split="train"):
    """Load data from files, then transform into the appropriate
    Dataset object.

    Args:
        task: one of ['cc', 'bp', 'mf', 'protease', 'flu']
        model_type: one of ['seq', 'struct', 'seq_struct']
        split: one of ['train', 'valid', 'test']

    Returns:
        Torch dataset.
    """
    seq_only = model_type == "seq"
    tokenizer = None
    if model_type != "struct":
        # 'seq' and 'seq_struct' models need the ProtBERT tokenizer
        print("Loading BertTokenizer...")
        tokenizer = BertTokenizer.from_pretrained(
            "Rostlab/prot_bert", do_lower_case=False
        )
    # Load data from files
    if task in ("cc", "bp", "mf"):  # GO dataset
        # load labels
        prot2annot, num_outputs, pos_weights = load_GO_labels(task)
        # load features
        dataset = load_gvp_data(
            task="DeepFRI_GO", split=split, seq_only=seq_only
        )
        add_GO_labels(dataset, prot2annot, go_ont=task)
    else:
        data_dir = {"protease": "protease/with_tags", "flu": "Fluorescence"}
        dataset = load_gvp_data(
            task=data_dir[task], split=split, seq_only=seq_only
        )
        num_outputs = 1
        pos_weights = None
    # Convert data into Dataset objects
    if model_type == "seq":
        if num_outputs == 1:
            targets = torch.tensor(
                [obj["target"] for obj in dataset], dtype=torch.float32
            ).unsqueeze(-1)
        else:
            targets = [obj["target"] for obj in dataset]
        dataset = SequenceDatasetWithTarget(
            [obj["seq"] for obj in dataset],
            targets,
            tokenizer=tokenizer,
            preprocess=True,
        )
    else:
        if num_outputs == 1:
            # convert each target to a float32 tensor of shape [1]
            for obj in dataset:
                obj["target"] = torch.tensor(
                    obj["target"], dtype=torch.float32
                ).unsqueeze(-1)
        if model_type == "struct":
            dataset = ProteinGraphDatasetWithTarget(dataset, preprocess=False)
        elif model_type == "seq_struct":
            dataset = preprocess_seqs(tokenizer, dataset)
            dataset = BertProteinGraphDatasetWithTarget(
                dataset, preprocess=False
            )
    dataset.num_outputs = num_outputs
    dataset.pos_weights = pos_weights
    return dataset
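
A minimal usage sketch (not part of the module): it assumes the lmgvp package and its preprocessed data files are available locally, and only reads back the attributes that get_dataset sets above.

# Hypothetical example: load the Fluorescence training split for the
# sequence-only model.
from lmgvp.data_loaders import get_dataset

train_set = get_dataset(task="flu", model_type="seq", split="train")
print(len(train_set))          # number of training proteins
print(train_set.num_outputs)   # 1 for the regression tasks ('protease', 'flu')
print(train_set.pos_weights)   # None for non-GO tasks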