in viz_dataset.py [0:0]
def load_data(data, split="train"):
    """Construct and return the dataset selected by ``data``.

    Args:
        data: Option name. Recognised values are ``"citibike"``,
            ``"covid_nj_cases"``, ``"earthquakes_jp"``, ``"pinwheel"``,
            ``"gmm"`` and ``"fmri"``.
        split: Forwarded unchanged as the ``split`` keyword argument of the
            selected dataset constructor. Defaults to ``"train"``.

    Returns:
        Whatever the selected dataset constructor returns.

    Raises:
        ValueError: If ``data`` equals none of the recognised option names.
    """
    # Ordered (option, factory) pairs. Each constructor call is wrapped in a
    # lambda so that only the matching dataset is looked up and instantiated.
    factories = (
        ("citibike", lambda: datasets.Citibike(split=split)),
        ("covid_nj_cases", lambda: datasets.CovidNJ(split=split)),
        ("earthquakes_jp", lambda: datasets.Earthquakes(split=split)),
        ("pinwheel", lambda: toy_datasets.PinwheelHawkes(split=split)),
        ("gmm", lambda: toy_datasets.GMMHawkes(split=split)),
        ("fmri", lambda: datasets.BOLD5000(split=split)),
    )

    # Compare with ``==`` in declaration order; the first match wins.
    for option, build in factories:
        if data == option:
            return build()

    raise ValueError(f"Unknown data option {data}")