in cv.py
import copy
import importlib
import json
import os
import shutil
from contextlib import nullcontext
from functools import partial

import pandas as pd

# load_config, mk_executor, run_cv, and post_slack_message are helpers assumed to be
# defined elsewhere in this module or package.


def run_best(config, module, remote, basedir, basedate=None, executor=None):
    # Instantiate the cross-validation class for the requested module and let it
    # select the best runs from the finished sweep.
    mod = importlib.import_module("covid19_spread." + module).CV_CLS()
    sweep_config = load_config(os.path.join(basedir, "cfg.yml"))
    best_runs = mod.model_selection(basedir, config=sweep_config[module], module=module)

    # For remote execution, create an executor for cluster submission if one was not supplied.
    if remote and executor is None:
        executor = mk_executor(
            "model_selection", basedir, config[module].get("resources", {})
        )

    # Record which runs were selected.
    with open(os.path.join(basedir, "model_selection.json"), "w") as fout:
        json.dump([x._asdict() for x in best_runs], fout)

    cfg = copy.deepcopy(config)
    best_runs_df = pd.DataFrame(best_runs)

    def run_cv_and_copy_results(tags, module, pth, cfg, prefix):
        try:
            jobs = run_cv(
                module,
                pth,
                cfg,
                prefix=prefix,
                basedate=basedate,
                executor=executor,
                test_run=True,
            )

            def rest():
                # Copy the final forecast (and, if configured, the prediction-interval
                # standard deviations) into the forecasts/ directory for each tag.
                for tag in tags:
                    shutil.copy(
                        os.path.join(pth, f'final_model_{cfg["validation"]["output"]}'),
                        os.path.join(
                            os.path.dirname(pth), f"forecasts/forecast_{tag}.csv"
                        ),
                    )
                    if "prediction_interval" in cfg:
                        piv_pth = os.path.join(
                            pth,
                            f'final_model_{cfg["prediction_interval"]["output_std"]}',
                        )
                        if os.path.exists(piv_pth):
                            shutil.copy(
                                piv_pth,
                                os.path.join(
                                    os.path.dirname(pth), f"forecasts/std_{tag}.csv"
                                ),
                            )

            if cfg[module]["train"].get("n_models", 1) > 1:
                # When training an ensemble, defer the copy step until all jobs finish.
                executor.submit_dependent(jobs, rest)
            else:
                rest()
        except Exception as e:
            msg = f"*Final run failed for {tags}*\nbasedir = {basedir}\nException was: {e}"
            post_slack_message(channel="#cron_errors", text=msg)
            raise e

    # Launch one final run per unique sweep directory, either locally or on the cluster.
    for pth, tags in best_runs_df.groupby("pth")["name"].agg(list).items():
        os.makedirs(os.path.join(os.path.dirname(pth), "forecasts"), exist_ok=True)
        name = ",".join(tags)
        print(f"Starting {name}: {pth}")
        job_config = load_config(os.path.join(pth, module + ".yml"))
        if "test" in cfg:
            job_config["train"]["test_on"] = cfg["test"]["days"]
        cfg[module] = job_config
        launcher = run_cv_and_copy_results
        if remote:
            # Submit the final run through the executor instead of running it inline.
            launcher = partial(executor.submit, run_cv_and_copy_results)

        with executor.set_folder(pth) if remote else nullcontext():
            launcher(tags, module, pth, cfg, "final_model_")
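

# Hedged usage sketch (not part of the original file): one way run_best might be
# invoked after a hyper-parameter sweep finishes. The basedir path and the module
# name "bar" are illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    example_basedir = "/tmp/sweeps/2021-01-01"  # hypothetical sweep output directory
    example_cfg = load_config(os.path.join(example_basedir, "cfg.yml"))
    run_best(
        config=example_cfg,
        module="bar",            # hypothetical module name; must expose a CV_CLS
        remote=False,            # run locally instead of submitting to a cluster
        basedir=example_basedir,
        basedate="2021-01-01",   # forecast base date forwarded to run_cv
    )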