def get_ploting_params()

in plot_toy_scatter.py [0:0]


    def get_ploting_params(df):
        """Collect everything needed to draw the toy scatter/heatmap plots.

        Loads the best checkpoint for each (method, init_seed, file_path) row
        group in ``df``, reshapes the accuracy columns of ``df`` into long
        format, regenerates one ``Toy("tr")`` dataset per seed, and evaluates
        every model on a heatmap plane built from its seed's dataset.

        Relies on names defined elsewhere in this module: ``idx``, ``seeds``,
        ``exps``, ``gammas``, ``load_model``, ``Toy``,
        ``generate_heatmap_plane``, ``DEVICE``, ``torch`` and ``np``.

        Args:
            df: DataFrame with columns ``method``, ``init_seed``,
                ``file_path``, ``min_acc_va``, ``min_acc_te``, ``min_acc_tr``
                and every column named in ``idx``.

        Returns:
            ``(exps, datasets, all_hm, gammas, heatmap_plane, df)`` where

            * ``datasets`` is a list of ``(d.x, d.y)`` pairs, one per seed;
            * ``all_hm`` is a CPU tensor of shape
              ``(len(exps), seeds, 200 * 200)`` holding each model's output on
              its seed's heatmap plane;
            * ``heatmap_plane`` is the plane built from the last seed's
              dataset, or ``None`` when ``seeds`` is 0;
            * ``df`` is the melted frame: ``phase`` is one of
              "valid"/"test"/"train", the metric is in ``error``, and the
              previous index is kept as an ``index`` column.
        """

        def best_checkpoint(path):
            # Swap only the last ".pt" so a ".pt" elsewhere in the path
            # (e.g. inside a directory name) is left untouched.
            head, sep, tail = path.rpartition(".pt")
            return head + ".best.pt" + tail if sep else path

        # Keyed by (method, init_seed). Models are looked up below by seed
        # index, so init_seed values are presumably 0..seeds-1 — confirm.
        models = {
            (exp, seed): load_model(best_checkpoint(path))
            for exp, seed, path in (
                df.groupby(["method", "init_seed", "file_path"]).groups.keys()
            )
        }

        df = (
            df.melt(
                id_vars=idx,
                value_vars=["min_acc_va", "min_acc_te", "min_acc_tr"],
                var_name="phase",
                value_name="error",
            )
            # Rename only inside "phase" so identical strings that happen to
            # appear in the id columns are not rewritten.
            .replace(
                {
                    "phase": {
                        "min_acc_va": "valid",
                        "min_acc_te": "test",
                        "min_acc_tr": "train",
                    }
                }
            )
            .reset_index()
        )

        datasets = []
        for i in range(seeds):
            # Reseed both global RNGs before building each dataset
            # (presumably Toy draws its samples from them).
            torch.manual_seed(i)
            np.random.seed(i)
            d = Toy("tr")
            datasets.append((d.x, d.y))

        # The plane depends only on the seed's dataset, so build it once per
        # seed rather than once per (exp, seed) pair.
        planes = [generate_heatmap_plane(x).to(DEVICE) for x, _ in datasets]

        all_hm = torch.zeros(len(exps), seeds, 200 * 200)
        # Pure inference: avoid building autograd graphs.
        with torch.no_grad():
            for exp_i, exp in enumerate(exps):
                for i, plane in enumerate(planes):
                    all_hm[exp_i, i] = models[(exp, i)](plane).cpu()

        # Callers receive the plane of the last seed, as before; None if no seeds.
        heatmap_plane = planes[-1] if planes else None
        return exps, datasets, all_hm, gammas, heatmap_plane, df