def load_multivariate_datasets()

in src/deep_demand_forecast/data.py [0:0]


def load_multivariate_datasets(path: Path) -> TrainDatasets:
    """Load the datasets stored under *path* and group them for multivariate use.

    The train and test splits are read from ``path / "train"`` and
    ``path / "test"``. Metadata is read from ``path / "metadata"``, except
    when *path* is exactly ``Path("raw_data")``, in which case *path* itself
    is used as the metadata location.

    Args:
        path: Root directory of the dataset.

    Returns:
        A ``TrainDatasets`` carrying the loaded metadata unchanged, with the
        train and test splits each passed through their own
        ``MultivariateGrouper``.

    Raises:
        ValueError: If the cardinality of the first static categorical
            feature cannot be converted to an integer.
    """
    metadata_path = path if path == Path("raw_data") else path / "metadata"
    ds = load_datasets(metadata_path, path / "train", path / "test")

    # The cardinality of the first static categorical feature is used as the
    # target dimension. NOTE(review): GluonTS declares this field as a string,
    # while the grouper needs an integer dimension, so coerce explicitly;
    # int() is a no-op if the value is already an int.
    target_dim = int(ds.metadata.feat_static_cat[0].cardinality)

    # Separate grouper instances are kept for each split, as in the original.
    grouper_train = MultivariateGrouper(max_target_dim=target_dim)
    grouper_test = MultivariateGrouper(max_target_dim=target_dim)
    return TrainDatasets(
        metadata=ds.metadata,
        train=grouper_train(ds.train),
        test=grouper_test(ds.test),
    )