def model_selection()

in covid19_spread/cross_val.py [0:0]


    def model_selection(self, basedir: str, config, module) -> List[BestRun]:
        """
        Evaluate a sweep returning a list of models to retrain on the full dataset.

        Args:
            basedir: Sweep directory. It is passed to ``self.metric_df`` and, when
                ablations are configured, ``ablation_map.json`` is written there.
            config: Sweep configuration. If ``config["train"]`` contains an
                ``"ablation"`` key, the best runs are selected per ablation setting.
            module: Module name; each job's config is read from
                ``<row.pth>/<module>.yml`` (ablation case only).

        Returns:
            Without ablations: one ``BestRun`` per metric in
            ``mae, rmse, mae_deltas, rmse_deltas, state_mae``, named
            ``best_<metric>``.
            With ablations: for each metric in ``mae, rmse, mae_deltas,
            rmse_deltas`` and each ablation group, a ``BestRun`` named
            ``best_<metric>_ablation_<i>``. The mapping from ``ablation_<i>``
            back to the ablation setting is written as JSON to
            ``<basedir>/ablation_map.json``.
        """
        df = self.metric_df(basedir)

        if "ablation" not in config["train"]:
            metrics = ["mae", "rmse", "mae_deltas", "rmse_deltas", "state_mae"]
            # Stable sort so ties resolve to the earliest row (the same
            # first-occurrence rule ``idxmin`` uses in the ablation branch).
            # NaN metrics sort last (pandas default ``na_position="last"``).
            return [
                BestRun(
                    df.sort_values(by=key, kind="stable").iloc[0].pth, f"best_{key}"
                )
                for key in metrics
            ]

        # Build a human-readable label for each run's ablation setting.
        ablations = []
        for _, row in df.iterrows():
            job_cfg = load_config(os.path.join(row.pth, f"{module}.yml"))
            # A job config without the key is treated like a null/empty ablation
            # instead of raising KeyError.
            paths = job_cfg["train"].get("ablation")
            if paths is not None and len(paths) > 0:
                # normpath strips trailing separators so "dir/" -> "dir" rather
                # than "" (which would collapse distinct ablations together).
                ablation = ",".join(
                    os.path.basename(os.path.normpath(x)) for x in paths
                )
            else:
                ablation = "null"
            ablations.append(ablation)

        # Number distinct labels in order of first appearance: ablation_0, ...
        ablation_map = {
            label: f"ablation_{i}"
            for i, label in enumerate(dict.fromkeys(ablations))
        }
        rev_map = {v: k for k, v in ablation_map.items()}
        df["ablation"] = [ablation_map[x] for x in ablations]
        with open(os.path.join(basedir, "ablation_map.json"), "w") as fout:
            print(json.dumps(rev_map), file=fout)

        best_runs = []
        for key in ["mae", "rmse", "mae_deltas", "rmse_deltas"]:
            # Drop NaN scores first: idxmin over an all-NaN group yields NaN (or
            # raises in newer pandas), which would break the .loc lookup. Such
            # groups are simply skipped for this metric.
            scored = df.dropna(subset=[key])
            best = scored.loc[scored.groupby("ablation")[key].idxmin()]
            best_runs.extend(
                BestRun(x.pth, f"best_{key}_{x.ablation}")
                for _, x in best.iterrows()
            )
        return best_runs