def prediction_interval()

in cv.py [0:0]


def prediction_interval(chkpnts, remote, nsamples, batchsize, closed_form):
    """Compute prediction intervals for each checkpoint in ``chkpnts``.

    For every checkpoint path, the model described by the job's config is
    rebuilt, the checkpoint weights are loaded into it, and
    ``run_standard_deviation`` / ``run_prediction_interval`` are run on the
    job's data.  Three CSVs are written next to the checkpoint:
    ``{prefix}std{suffix}.csv``, ``{prefix}mean{suffix}.csv`` and
    ``{prefix}piv{suffix}.csv``.

    Args:
        chkpnts: iterable of checkpoint file paths.
        remote: if truthy, submit one array job per checkpoint through
            ``mk_executor`` and print the log folder; otherwise process the
            checkpoints serially in this process.
        nsamples: number of samples; when falsy, falls back to the config's
            ``prediction_interval.nsamples`` (default 100).
        batchsize: batch size; when falsy, falls back to the config's
            ``prediction_interval.batch_size`` (default 8).
        closed_form: forwarded to ``run_standard_deviation``; when truthy the
            output file names carry a ``_closed_form`` suffix.
    """

    def f(chkpnt_pth):
        # Checkpoints named "final_model_*" have their sibling job files
        # (job config, preprocessed/filtered data, outputs) prefixed the same way.
        prefix = "final_model_" if "final_model_" in chkpnt_pth else ""
        chkpnt = th.load(chkpnt_pth)
        job_pth = os.path.dirname(chkpnt_pth)

        # cfg.yml is looked up one directory above the job, then two.
        cfg_pth = os.path.join(job_pth, "../cfg.yml")
        if not os.path.exists(cfg_pth):
            cfg_pth = os.path.join(job_pth, "../../cfg.yml")
        cfg = load_config(cfg_pth)
        module = cfg["this_module"]
        job_config = load_config(os.path.join(job_pth, f"{prefix}{module}.yml"))
        opt = Namespace(**job_config["train"])
        mod = importlib.import_module("covid19_spread." + module).CV_CLS()
        _new_cases, _regions, basedate, _device = mod.initialize(opt)
        model = mod.func
        model.load_state_dict(chkpnt)

        data_name = os.path.basename(job_config["data"])
        dset = os.path.join(job_pth, prefix + "preprocessed_" + data_name)
        val_in = os.path.join(job_pth, prefix + "filtered_" + data_name)

        gt = metrics.load_ground_truth(val_in)
        # List indexer selects the basedate row while keeping a 2-D result
        # (presumably gt is a DataFrame -- confirm against metrics).
        prev_day = gt.loc[[pd.to_datetime(basedate)]]

        # Resolve config defaults once so both stages use the same intervals.
        pred_interval = cfg.get("prediction_interval", {})
        intervals = pred_interval.get("intervals", [0.99, 0.95, 0.8])
        df_std, df_mean = mod.run_standard_deviation(
            dset,
            opt,
            nsamples or pred_interval.get("nsamples", 100),
            intervals,
            prev_day.values.T,
            model,
            batchsize or pred_interval.get("batch_size", 8),
            closed_form=closed_form,
        )

        suffix = "_closed_form" if closed_form else ""
        std_pth = os.path.join(job_pth, f"{prefix}std{suffix}.csv")
        mean_pth = os.path.join(job_pth, f"{prefix}mean{suffix}.csv")
        df_std.to_csv(std_pth, index_label="date")
        df_mean.to_csv(mean_pth, index_label="date")

        # run_prediction_interval reads back the mean/std CSVs just written.
        pred_intervals = mod.run_prediction_interval(mean_pth, std_pth, intervals)
        pred_intervals.to_csv(
            os.path.join(job_pth, f"{prefix}piv{suffix}.csv"), index=False
        )

    if remote:
        now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        folder = os.path.expanduser(f"~/.covid19/logs/{now}")
        extra_params = {"gpus": 1, "cpus": 2, "memgb": 20, "timeout": 3600}
        ex = mk_executor(
            "prediction_interval", folder, extra_params, ex=submitit.AutoExecutor
        )
        # One array task per checkpoint; the folder is printed so the user
        # can locate the job logs.
        ex.map_array(f, chkpnts)
        print(folder)
    else:
        for chkpnt_pth in chkpnts:
            f(chkpnt_pth)