def get_dataset()

in lmgvp/data_loaders.py [0:0]


def get_dataset(task="", model_type="", split="train"):
    """Load data from files, then transform into the appropriate Dataset.

    Args:
        task: one of ['cc', 'bp', 'mf', 'protease', 'flu']. The three GO
            ontologies ('cc', 'bp', 'mf') share the "DeepFRI_GO" feature
            files and differ only in the labels attached to them.
        model_type: one of ['seq', 'struct', 'seq_struct'].
        split: one of ['train', 'valid', 'test']; passed through to
            ``load_gvp_data``.

    Returns:
        Torch dataset with two extra attributes attached:
        ``num_outputs`` (1 for 'protease'/'flu', otherwise the value returned
        by ``load_GO_labels``) and ``pos_weights`` (None for
        'protease'/'flu', otherwise the value returned by ``load_GO_labels``).

    Raises:
        ValueError: if ``task`` or ``model_type`` is not one of the supported
            values.
    """
    go_tasks = ("cc", "bp", "mf")
    # Single-target tasks mapped to the data sub-directory they live in.
    data_dir = {"protease": "protease/with_tags", "flu": "Fluorescence"}
    model_types = ("seq", "struct", "seq_struct")

    # Validate before loading the tokenizer or any data files, so a typo
    # fails fast with a clear message instead of a KeyError on `data_dir`
    # or an AttributeError on a plain list at the end of the function.
    if task not in go_tasks and task not in data_dir:
        raise ValueError(
            f"Unknown task {task!r}; expected one of "
            f"{list(go_tasks) + list(data_dir)}"
        )
    if model_type not in model_types:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of "
            f"{list(model_types)}"
        )

    seq_only = model_type == "seq"

    tokenizer = None
    if model_type != "struct":
        # Sequence-based models ('seq', 'seq_struct') need the ProtBert
        # tokenizer.
        print("Loading BertTokenizer...")
        tokenizer = BertTokenizer.from_pretrained(
            "Rostlab/prot_bert", do_lower_case=False
        )

    # Load data from files
    if task in go_tasks:  # GO dataset
        # load labels
        prot2annot, num_outputs, pos_weights = load_GO_labels(task)
        # load features, then attach the labels for this ontology
        dataset = load_gvp_data(
            task="DeepFRI_GO", split=split, seq_only=seq_only
        )
        add_GO_labels(dataset, prot2annot, go_ont=task)
    else:
        dataset = load_gvp_data(
            task=data_dir[task], split=split, seq_only=seq_only
        )
        num_outputs = 1
        pos_weights = None

    # Convert data into Dataset objects
    if model_type == "seq":
        if num_outputs == 1:
            # Stack targets into one float32 tensor with a trailing dim of 1
            # (assumes each target is a scalar — TODO confirm).
            targets = torch.tensor(
                [obj["target"] for obj in dataset], dtype=torch.float32
            ).unsqueeze(-1)
        else:
            # Multi-output targets are passed through as a list.
            targets = [obj["target"] for obj in dataset]
        dataset = SequenceDatasetWithTarget(
            [obj["seq"] for obj in dataset],
            targets,
            tokenizer=tokenizer,
            preprocess=True,
        )
    else:
        if num_outputs == 1:
            # convert each target to a float32 tensor with a trailing dim of 1
            for obj in dataset:
                obj["target"] = torch.tensor(
                    obj["target"], dtype=torch.float32
                ).unsqueeze(-1)
        if model_type == "struct":
            dataset = ProteinGraphDatasetWithTarget(dataset, preprocess=False)
        else:  # "seq_struct" (guaranteed by the validation above)
            dataset = preprocess_seqs(tokenizer, dataset)
            dataset = BertProteinGraphDatasetWithTarget(
                dataset, preprocess=False
            )

    dataset.num_outputs = num_outputs
    dataset.pos_weights = pos_weights
    return dataset