def get_mixed_pattern()

in src/gluonts/nursery/SCott/dataset_tools/synthetic.py [0:0]


def get_mixed_pattern(
    unit_length=16, num_duplicates=1000, output_path="../dataset/mix.csv"
):
    """Build a synthetic "mixed pattern" dataset and pickle it to disk.

    Every series is the concatenation of ``pattern_number - 1`` context
    segments followed by one prediction segment, each ``unit_length`` points
    long. Each segment is produced by ``_get_mixed_pattern`` from an index
    ramp ``0 .. unit_length - 1`` and a pattern name taken from one row of
    ``patterns``:
    - the prediction segment uses ``pattern[gid]``;
    - context segment ``i`` (1-based) uses ``pattern[(gid + i) % pattern_number]``.

    Each ``(row, gid)`` pair forms one group of ``num_duplicates`` series.
    All series start at 2000-01-01 with frequency ``"1H"``.

    Args:
        unit_length: Points per segment. The prediction length equals this;
            the context length is ``(pattern_number - 1) * unit_length``.
        num_duplicates: Number of series generated per ``(row, gid)`` group.
        output_path: File the result dict is pickled to. The default keeps the
            historical location. Note that the file is a pickle despite the
            ``.csv`` suffix. Its parent directory is created if missing.

    Returns:
        True once the dataset has been written.

    The pickled dict contains:
        "group_ratio": fraction of all series that belongs to each group.
        "whole_data": ListDataset of every series, shuffled.
        "val_data": ListDataset of every 5th duplicate of each group, with
            N(0, 1) noise added to the target.
        "group_data": list with one shuffled ListDataset per group.
    """
    import os

    freq = "1H"
    prediction_length = unit_length
    # pandas >= 2.0 removed the ``freq`` argument of ``pd.Timestamp``; the
    # frequency is attached by ``ListDataset(..., freq=freq)`` below instead.
    start = pd.Timestamp("01-01-2000")
    patterns = [
        ["sin", "linear", "quadratic", "sqrt"],
        ["sqrt", "quadratic", "linear", "sin"],
        ["linear", "sqrt", "sin", "quadratic"],
        ["quadratic", "sin", "sqrt", "linear"],
    ]
    pattern_number = len(patterns[0])
    # One context segment per pattern except the one used for prediction.
    num_context_segments = pattern_number - 1
    context_length = num_context_segments * unit_length

    # One bucket per (pattern row, gid) pair, indexed as m * pattern_number + gid.
    dataset_group = [[] for _ in range(len(patterns) * pattern_number)]
    whole_data = []
    val_data = []

    for m, pattern in enumerate(patterns):
        for gid in range(pattern_number):
            bucket = dataset_group[m * pattern_number + gid]
            for j in range(num_duplicates):
                context = torch.zeros(context_length, dtype=torch.float)
                for i in range(1, pattern_number):
                    # A fresh ramp per call, in case the helper mutates its input.
                    context[unit_length * (i - 1) : unit_length * i] = (
                        _get_mixed_pattern(
                            torch.arange(unit_length, dtype=torch.float),
                            pattern[(gid + i) % pattern_number],
                        )
                    )
                prediction = _get_mixed_pattern(
                    torch.arange(prediction_length, dtype=torch.float),
                    pattern[gid],
                )
                ts_sample = torch.cat([context, prediction])
                whole_data.append({"target": ts_sample, "start": start})
                # Every 5th duplicate also feeds the noisy validation set.
                if j % 5 == 0:
                    noise = torch.normal(0, 1, ts_sample.shape)
                    val_data.append(
                        {"target": ts_sample + noise, "start": start}
                    )
                bucket.append({"target": ts_sample, "start": start})

    print(len(whole_data))
    print(len(val_data))

    ret = dict()
    total = len(whole_data)
    # Guard against num_duplicates == 0, which would otherwise divide by zero.
    ret["group_ratio"] = [
        len(group) / total if total else 0.0 for group in dataset_group
    ]
    print(ret["group_ratio"])
    random.shuffle(whole_data)
    ret["whole_data"] = ListDataset(whole_data, freq=freq)
    ret["val_data"] = ListDataset(val_data, freq=freq)
    group_data = []
    for group in dataset_group:
        random.shuffle(group)
        group_data.append(ListDataset(group, freq=freq))
    ret["group_data"] = group_data

    # Save to file. The payload is a pickle even though the default name ends
    # in ".csv"; the name is kept so existing readers keep working.
    directory = os.path.dirname(output_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(output_path, "wb") as output:
        pickle.dump(ret, output)

    return True