def generate_lstnet_dataset()

in src/gluonts/nursery/SCott/pts/dataset/repository/_lstnet.py [0:0]


def generate_lstnet_dataset(dataset_path: Path, dataset_name: str):
    """Download an LSTNet dataset and write it below ``dataset_path``.

    The CSV at ``datasets_info[dataset_name].url`` is read without a header
    row, checked against the expected ``(num_time_steps, num_series)`` shape
    and handed to ``load_from_pandas`` together with a regular time index
    built from the dataset's start date and frequency.

    Three files are written:

    * ``metadata.json`` -- output of ``metadata(...)`` for the dataset's
      cardinality, frequency and prediction length;
    * ``train/data.json`` -- every series sliced up to the training end,
      i.e. the timestamp found at position ``int(0.8 * len(index))`` of the
      first series' index;
    * ``test/data.json`` -- for each of the ``rolling_evaluations`` windows,
      every series sliced up to
      ``frequency_add(window_start, prediction_length)``.

    Train and test records carry the same fields (``cat`` and ``item_id``
    are both derived from the position of the series), so a test entry can
    be matched to its training series.

    Parameters
    ----------
    dataset_path
        Directory that receives the generated files; created if missing.
    dataset_name
        Key into ``datasets_info`` selecting the dataset to generate.

    Raises
    ------
    AssertionError
        If the downloaded table does not have the expected shape, or if the
        number of generated train/test records is not the expected one.
    """
    ds_info = datasets_info[dataset_name]

    time_index = pd.date_range(
        start=ds_info.start_date,
        freq=ds_info.freq,
        periods=ds_info.num_time_steps,
    )

    # Fetch and validate the raw data before touching the file system, so a
    # failed download or a shape mismatch does not leave behind a directory
    # that contains nothing but metadata.json.
    df = pd.read_csv(ds_info.url, header=None)

    assert df.shape == (
        ds_info.num_time_steps,
        ds_info.num_series,
    ), f"expected num_time_steps/num_series {(ds_info.num_time_steps, ds_info.num_series)} but got {df.shape}"

    os.makedirs(dataset_path, exist_ok=True)

    with open(dataset_path / "metadata.json", "w") as f:
        f.write(
            json.dumps(
                metadata(
                    cardinality=ds_info.num_series,
                    freq=ds_info.freq,
                    prediction_length=ds_info.prediction_length,
                )
            )
        )

    train_file = dataset_path / "train" / "data.json"
    test_file = dataset_path / "test" / "data.json"

    timeseries = load_from_pandas(
        df=df, time_index=time_index, agg_freq=ds_info.agg_freq
    )

    # the last date seen during training
    ts_index = timeseries[0].index
    training_end = ts_index[int(len(ts_index) * (8 / 10))]

    train_ts = []
    for cat, ts in enumerate(timeseries):
        sliced_ts = ts[:training_end]
        # Series with no observation up to the training end are skipped; the
        # assertion below then reports the shortfall.
        if len(sliced_ts) > 0:
            train_ts.append(
                to_dict(
                    target_values=sliced_ts.values,
                    start=sliced_ts.index[0],
                    cat=[cat],
                    item_id=cat,
                )
            )

    assert len(train_ts) == ds_info.num_series, (
        f"expected {ds_info.num_series} training series "
        f"but got {len(train_ts)}"
    )

    save_to_file(train_file, train_ts)

    # time of the first prediction of each rolling-evaluation window
    prediction_dates = [
        frequency_add(training_end, i * ds_info.prediction_length)
        for i in range(ds_info.rolling_evaluations)
    ]

    test_ts = []
    for prediction_start_date in prediction_dates:
        # The window end depends only on the window start, so compute it once
        # per window rather than once per series.
        prediction_end_date = frequency_add(
            prediction_start_date, ds_info.prediction_length
        )
        for cat, ts in enumerate(timeseries):
            sliced_ts = ts[:prediction_end_date]
            test_ts.append(
                to_dict(
                    target_values=sliced_ts.values,
                    start=sliced_ts.index[0],
                    cat=[cat],
                    # Same identifier as the matching training record.
                    item_id=cat,
                )
            )

    expected_test_count = ds_info.num_series * ds_info.rolling_evaluations
    assert len(test_ts) == expected_test_count, (
        f"expected {expected_test_count} test series but got {len(test_ts)}"
    )

    save_to_file(test_file, test_ts)