def load_dataset()

in src/preprocess/data_preprocessor.py [0:0]


def load_dataset(dataset_name: str, path: Path) -> TrainDatasets:
    """Load a GluonTS dataset and group its series into multivariate form.

    The dataset is obtained with ``get_dataset(dataset_name, path,
    regenerate=False)``. Its ``train`` and ``test`` splits are each passed
    through their own ``MultivariateGrouper``. The grouper's
    ``max_target_dim`` is the cardinality of the first static categorical
    feature in the dataset metadata (presumably the number of individual
    series; verify per dataset).

    Args:
        dataset_name: Dataset identifier forwarded to ``get_dataset``.
        path: Directory forwarded to ``get_dataset`` as the dataset location.

    Returns:
        A ``TrainDatasets`` carrying the original metadata and the grouped
        ``train`` and ``test`` splits.

    Raises:
        IndexError: If ``metadata.feat_static_cat`` is empty.
        ValueError: If the cardinality cannot be parsed as an integer.
    """
    dataset = get_dataset(dataset_name, path, regenerate=False)

    # GluonTS declares ``CategoricalFeatureInfo.cardinality`` as a string.
    # MultivariateGrouper uses ``max_target_dim`` as a slice bound when it
    # restricts the target dimensionality, so the value must be an int.
    # A string would break that slicing.
    target_dim = int(dataset.metadata.feat_static_cat[0].cardinality)

    # One grouper per split. Presumably the grouper records per-dataset
    # state (timestamps, frequency) when called, so instances are not shared.
    # NOTE(review): no ``num_test_dates`` is passed, so the test split is
    # grouped the same way as the training split. For rolling-evaluation
    # datasets this may collapse several test dates into one; confirm this
    # is intended.
    grouper_train = MultivariateGrouper(max_target_dim=target_dim)
    grouper_test = MultivariateGrouper(max_target_dim=target_dim)

    return TrainDatasets(
        metadata=dataset.metadata,
        train=grouper_train(dataset.train),
        test=grouper_test(dataset.test),
    )