in src/preprocess/data_preprocessor.py [0:0]
def load_dataset(dataset_name: str, path: Path) -> TrainDatasets:
    """Load a GluonTS dataset and group its series into multivariate form.

    The raw dataset is fetched with ``get_dataset`` (without regenerating it)
    and both splits are passed through ``MultivariateGrouper`` so that the
    individual series become dimensions of a multivariate target.

    Args:
        dataset_name: Dataset name forwarded to ``get_dataset``.
        path: Dataset root directory forwarded to ``get_dataset``.

    Returns:
        A ``TrainDatasets`` carrying the original metadata and the grouped
        train and test splits.

    Raises:
        ValueError: If the training split is empty, because the number of test
            windows is derived from the ratio of test to train entries.
    """
    dataset = get_dataset(dataset_name, path, regenerate=False)

    # GluonTS declares ``CategoricalFeatureInfo.cardinality`` as a string, while
    # the grouper slices the target with ``-max_target_dim``, which needs an int.
    target_dim = int(dataset.metadata.feat_static_cat[0].cardinality)

    num_train_series = len(dataset.train)
    if num_train_series == 0:
        raise ValueError(
            f"dataset {dataset_name!r} has an empty training split"
        )
    # Presumably (standard GluonTS layout) the test split holds one copy of
    # every series per rolling evaluation date -- confirm for custom datasets.
    # Without ``num_test_dates`` the grouper treats the test split like
    # training data and truncates the stacked target to ``target_dim`` rows,
    # silently dropping all but one evaluation window.
    num_test_dates = len(dataset.test) // num_train_series

    grouper_train = MultivariateGrouper(max_target_dim=target_dim)
    grouper_test = MultivariateGrouper(
        num_test_dates=num_test_dates,
        max_target_dim=target_dim,
    )
    return TrainDatasets(
        metadata=dataset.metadata,
        train=grouper_train(dataset.train),
        test=grouper_test(dataset.test),
    )