in datasets.py [0:0]
def __init__(self, split="train"):
    """Load the COVID-19 NJ case arrays and select one data split.

    Every array stored in ``data/covid19/covid_nj_cases.npz`` is grouped by
    the first 8 characters of its name (presumably a YYYYMMDD date --
    TODO confirm against the file that produces the archive). Groups are
    taken in the order they first appear in the archive and partitioned
    in cycles of 27:

    * positions 0-7 of each cycle are withheld from training,
    * position 2 of each cycle is the validation group,
    * position 5 of each cycle is the test group,
    * positions 8-26 form the training groups.

    Args:
        split: Which split to expose; one of ``"train"``, ``"val"`` or
            ``"test"``.

    Raises:
        AssertionError: If ``split`` is not one of the accepted names.
    """
    assert split in ["train", "val", "test"]
    self.split = split

    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager guarantees the file handle is released.
    with np.load("data/covid19/covid_nj_cases.npz") as dataset:
        files = list(dataset.files)

        # dict.fromkeys de-duplicates while keeping first-seen order, which
        # the strided selections below depend on.
        dates = list(dict.fromkeys(f[:8] for f in files))

        # Reduce contamination between train/val/test splits: the groups
        # surrounding the val/test positions are withheld from training.
        exclude_from_train = {d for i, d in enumerate(dates) if i % 27 < 8}
        val_dates = set(dates[2::27])
        test_dates = set(dates[5::27])
        train_dates = set(dates) - exclude_from_train
        date_splits = {"train": train_dates, "val": val_dates, "test": test_dates}

        # Arrays are read fully into memory here, so they remain valid after
        # the archive is closed. The training arrays are always loaded because
        # the parent constructor receives them regardless of the split.
        train_set = [dataset[f] for f in files if f[:8] in train_dates]
        split_set = [dataset[f] for f in files if f[:8] in date_splits[split]]

    super().__init__(train_set, split_set, split == "train")