in train_stpp.py [0:0]
def get_t0_t1(data):
    """Return the ``(t0, t1)`` time-bound tensors for a named dataset.

    Both elements of the returned tuple are one-element tensors. ``t0`` is
    always ``tensor([0.0])``; ``t1`` holds a constant chosen by dataset name.
    For ``"pinwheel"`` and ``"gmm"`` that constant is read from
    ``toy_datasets.END_TIME`` at call time.

    Args:
        data: Dataset name. Recognized values are ``"citibike"``,
            ``"covid_nj_cases"``, ``"earthquakes_jp"``, ``"pinwheel"``,
            ``"gmm"`` and ``"fmri"``.

    Returns:
        Tuple ``(t0, t1)`` of one-element tensors.

    Raises:
        ValueError: If ``data`` is not one of the recognized names.
    """
    # Ordered (name, end value) pairs. ``None`` marks the toy datasets, whose
    # end value is looked up lazily so ``toy_datasets`` is only touched when
    # one of them is actually requested.
    end_values = (
        ("citibike", 24.0),
        ("covid_nj_cases", 7.0),
        ("earthquakes_jp", 30.0),
        ("pinwheel", None),
        ("gmm", None),
        ("fmri", 10.0),
    )

    for known_name, end_value in end_values:
        if data == known_name:
            if end_value is None:
                end_value = toy_datasets.END_TIME
            return torch.tensor([0.0]), torch.tensor([end_value])

    raise ValueError(f"Unknown dataset {data}")