in src/gluonts/nursery/SCott/preprocess_data.py [0:0]
def get_mixed_pattern(unit_length=16, num_duplicates=1000):
    """Generate a synthetic hourly dataset of mixed-pattern time series and
    pickle it to ``./dataset/synthetic.csv``.

    Each sample is ``3 * unit_length`` context steps followed by a
    ``unit_length`` prediction window. The context is built from three
    pattern segments (rotated through ``patterns``) and the prediction
    window from a fourth, all rendered by ``_get_mixed_pattern``
    (defined elsewhere in this module; exact waveform semantics live there).

    Parameters
    ----------
    unit_length : int
        Length of one pattern segment (prediction window length).
    num_duplicates : int
        Number of samples generated per (pattern-row, group) combination.

    Returns
    -------
    bool
        ``True`` on success; the real output is the pickled dict with keys
        ``whole_data``, ``val_data``, ``group_data``, ``group_ratio``.
    """
    freq = "1H"
    context_length = 3 * unit_length
    prediction_length = unit_length
    len_sample = context_length + prediction_length
    patterns = [
        ["sin", "linear", "quadratic", "sqrt"],
        ["sqrt", "quadratic", "linear", "sin"],
        ["linear", "sqrt", "sin", "quadratic"],
        ["quadratic", "sin", "sqrt", "linear"],
    ]
    pattern_number = 4
    # One bucket per (pattern row, group id) pair; previously a hard-coded 16.
    dataset_group = [[] for _ in range(len(patterns) * pattern_number)]
    whole_data = []
    val_data = []
    ret = dict()
    start = pd.Timestamp("01-01-2000", freq=freq)
    for m, pattern in enumerate(patterns):
        for gid in range(pattern_number):
            for j in range(num_duplicates):
                # Fill the three context segments with rotated patterns;
                # each segment is re-based to start from 0 before rendering.
                context = torch.arange(context_length, dtype=torch.float)
                for i in range(1, pattern_number):
                    context[
                        unit_length * (i - 1) : unit_length * i
                    ] = _get_mixed_pattern(
                        context[unit_length * (i - 1) : unit_length * i]
                        - unit_length * (i - 1),
                        pattern[(gid + i) % pattern_number],
                    )
                # Prediction window uses the group's own pattern.
                ts_sample = torch.cat(
                    [
                        context,
                        _get_mixed_pattern(
                            torch.arange(prediction_length, dtype=torch.float),
                            pattern[gid],
                        ),
                    ]
                )
                whole_data.append({"target": ts_sample, "start": start})
                # Every 5th sample also goes to validation, with unit
                # Gaussian noise added so val differs from train.
                if j % 5 == 0:
                    val_data.append(
                        {
                            "target": ts_sample
                            + torch.normal(0, 1, ts_sample.shape),
                            "start": start,
                        }
                    )
                # Was ``m * 4 + gid``; derive the stride from pattern_number
                # so the group index stays correct if the table changes.
                dataset_group[m * pattern_number + gid].append(
                    {"target": ts_sample, "start": start}
                )
    print(
        "Generating the synthetic training data, the total number of training examples:",
        len(whole_data),
    )
    ret["group_ratio"] = [len(i) / len(whole_data) for i in dataset_group]
    random.shuffle(whole_data)
    group_data = []
    ret["whole_data"] = ListDataset(whole_data, freq=freq)
    ret["val_data"] = ListDataset(val_data, freq=freq)
    for group in dataset_group:
        random.shuffle(group)
        group_data.append(ListDataset(group, freq=freq))
    ret["group_data"] = group_data
    # Save to files. NOTE(review): the payload is a pickle, not CSV, despite
    # the ".csv" extension — kept as-is because downstream loaders use this path.
    os.makedirs("./dataset", exist_ok=True)
    with open("./dataset/synthetic.csv", "wb") as output:
        pickle.dump(ret, output)
    print("Finished the pre-processing of synthetic dataset")
    return True