in cv.py [0:0]
def prediction_interval(chkpnts, remote, nsamples, batchsize, closed_form):
    """Compute prediction intervals for one or more trained checkpoints.

    For every checkpoint path in ``chkpnts`` this rebuilds the model from the
    job's saved configuration and loads the checkpoint weights. It then calls
    the module's ``run_standard_deviation`` to obtain std/mean frames and
    ``run_prediction_interval`` to turn those into intervals. Three CSV files
    are written into the checkpoint's directory:
    ``{prefix}std{suffix}.csv``, ``{prefix}mean{suffix}.csv`` and
    ``{prefix}piv{suffix}.csv``.

    Args:
        chkpnts: iterable of checkpoint file paths.
        remote: if truthy, submit one job per checkpoint through a submitit
            executor and print the log folder. Otherwise process the
            checkpoints sequentially in this process.
        nsamples: number of samples. A falsy value falls back to the
            config's ``prediction_interval.nsamples`` (default 100).
        batchsize: batch size. A falsy value falls back to the config's
            ``prediction_interval.batch_size`` (default 8).
        closed_form: forwarded to ``run_standard_deviation``. When truthy,
            the output file names get a ``_closed_form`` suffix.
    """

    def f(chkpnt_pth):
        # Artifacts of the final-model stage carry a ``final_model_`` prefix.
        # Inspect only the file name, so that a parent directory containing
        # that string does not change which sibling files are looked up.
        prefix = (
            "final_model_" if "final_model_" in os.path.basename(chkpnt_pth) else ""
        )
        # NOTE(review): no map_location is given, so tensors are restored to
        # the device they were saved from.
        chkpnt = th.load(chkpnt_pth)
        job_pth = os.path.dirname(chkpnt_pth)

        # The shared cfg.yml is searched one, then two, directories above the job.
        cfg_pth = os.path.join(job_pth, "../cfg.yml")
        if not os.path.exists(cfg_pth):
            cfg_pth = os.path.join(job_pth, "../../cfg.yml")
        cfg = load_config(cfg_pth)
        module = cfg["this_module"]
        job_config = load_config(os.path.join(job_pth, f"{prefix}{module}.yml"))
        opt = Namespace(**job_config["train"])

        mod = importlib.import_module("covid19_spread." + module).CV_CLS()
        # Only the base date is needed from initialize(); the other values
        # are discarded.
        _, _, basedate, _ = mod.initialize(opt)
        model = mod.func
        model.load_state_dict(chkpnt)

        data_name = os.path.basename(job_config["data"])
        dset = os.path.join(job_pth, prefix + "preprocessed_" + data_name)
        val_in = os.path.join(job_pth, prefix + "filtered_" + data_name)
        gt = metrics.load_ground_truth(val_in)
        # Ground-truth row for the base date. It is passed (transposed) to
        # run_standard_deviation, presumably as the starting state — confirm.
        prev_day = gt.loc[[pd.to_datetime(basedate)]]

        pred_interval = cfg.get("prediction_interval", {})
        # Read once so the std computation and the interval computation
        # are guaranteed to use the same interval levels.
        intervals = pred_interval.get("intervals", [0.99, 0.95, 0.8])
        df_std, df_mean = mod.run_standard_deviation(
            dset,
            opt,
            nsamples or pred_interval.get("nsamples", 100),
            intervals,
            prev_day.values.T,
            model,
            batchsize or pred_interval.get("batch_size", 8),
            closed_form=closed_form,
        )

        suffix = "_closed_form" if closed_form else ""
        std_pth = os.path.join(job_pth, f"{prefix}std{suffix}.csv")
        mean_pth = os.path.join(job_pth, f"{prefix}mean{suffix}.csv")
        df_std.to_csv(std_pth, index_label="date")
        df_mean.to_csv(mean_pth, index_label="date")
        # run_prediction_interval consumes the CSVs just written, not the frames.
        pred_intervals = mod.run_prediction_interval(mean_pth, std_pth, intervals)
        pred_intervals.to_csv(
            os.path.join(job_pth, f"{prefix}piv{suffix}.csv"), index=False
        )

    if remote:
        now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        folder = os.path.expanduser(f"~/.covid19/logs/{now}")
        extra_params = {"gpus": 1, "cpus": 2, "memgb": 20, "timeout": 3600}
        ex = mk_executor(
            "prediction_interval", folder, extra_params, ex=submitit.AutoExecutor
        )
        # One array job per checkpoint; the log folder is printed for the user.
        ex.map_array(f, chkpnts)
        print(folder)
    else:
        for chkpnt_pth in chkpnts:
            f(chkpnt_pth)