in plot_toy_scatter.py [0:0]
def get_ploting_params(df):
    """Load the best checkpoints and assemble the data for the toy scatter plots.

    NOTE(review): this function reads several names that are not defined in
    it (``idx``, ``seeds``, ``exps``, ``gammas``, ``load_model``, ``Toy``,
    ``generate_heatmap_plane``, ``DEVICE``, ``torch``, ``np``); presumably
    they are module-level globals/imports -- confirm against the rest of the
    file.

    Args:
        df: results table with at least the columns ``method``, ``init_seed``,
            ``file_path``, ``min_acc_va``, ``min_acc_te``, ``min_acc_tr`` and
            the identifier columns listed in ``idx``.

    Returns:
        Tuple ``(exps, datasets, all_hm, gammas, heatmap_plane, df)`` where:
          * ``datasets[i]`` is the ``(x, y)`` pair of ``Toy("tr")`` built after
            seeding torch and numpy with ``i``;
          * ``all_hm`` is a ``(len(exps), seeds, 200 * 200)`` tensor holding,
            for each experiment and seed, the model output on that seed's
            heatmap plane (moved to CPU);
          * ``heatmap_plane`` is the plane generated for the last seed;
          * ``df`` is the input melted to long format, with ``phase`` in
            {"valid", "test", "train"}, ``error`` holding the matching
            ``min_acc_*`` value, and an ``index`` column added by
            ``reset_index``.

    Raises:
        ValueError: if ``seeds`` is less than 1 (there is no plane to return).
    """
    # One model per (method, init_seed). The ".best.pt" checkpoint path is
    # derived from the recorded ".pt" path; rsplit swaps only the LAST ".pt",
    # so a ".pt" earlier in the path (e.g. in a directory name) is untouched.
    models = {
        (exp, seed): load_model(".best.pt".join(path.rsplit(".pt", 1)))
        for exp, seed, path in df.groupby(
            ["method", "init_seed", "file_path"]
        ).groups
    }

    df = (
        df.melt(
            id_vars=idx,
            value_vars=["min_acc_va", "min_acc_te", "min_acc_tr"],
            var_name="phase",
            value_name="error",
        )
        # Scope the rename to the "phase" column so identical strings that
        # happen to appear in other columns are never rewritten.
        .replace(
            {
                "phase": {
                    "min_acc_va": "valid",
                    "min_acc_te": "test",
                    "min_acc_tr": "train",
                }
            }
        )
        .reset_index()
    )

    # Rebuild each seed's Toy training split under a fixed seed.
    # NOTE: this mutates the global torch and numpy RNG state.
    datasets = []
    for i in range(seeds):
        torch.manual_seed(i)
        np.random.seed(i)
        d = Toy("tr")
        datasets.append((d.x, d.y))

    if not datasets:
        raise ValueError("seeds must be >= 1 to build a heatmap plane")

    # The plane's only input is the seed's x, which is the same for every
    # experiment, so build it once per seed rather than once per (exp, seed).
    planes = [generate_heatmap_plane(x).to(DEVICE) for x, _ in datasets]

    all_hm = torch.zeros(len(exps), seeds, 200 * 200)
    # Inference only: no_grad avoids building an autograd graph per forward.
    with torch.no_grad():
        for exp_i, exp in enumerate(exps):
            for i, plane in enumerate(planes):
                all_hm[exp_i, i] = models[(exp, i)](plane).cpu()

    heatmap_plane = planes[-1]
    return exps, datasets, all_hm, gammas, heatmap_plane, df