def private_mse_and_fil()

in scripts/make_figures.py [0:0]


def private_mse_and_fil(results_path, save_path):
    """Plot FIL histograms and private-inversion test error for several L2 values.

    For every L2 setting in a fixed list, two result files are read through
    ``load_results`` from ``results_path``:

    * ``iwpc_least_squares_fil_l2_{l2}.json`` -- its ``"etas"`` entry is used
      as the collection of per-example FIL values.
    * ``iwpc_least_squares_whitebox_private_inversion_l2_{l2}.json`` -- for
      each noise scale, ``result[0][noise_scale]['test_acc']`` is collected.

    Two figures are then produced with ``plotting.line_plot``:

    * ``iwpc_fil_counts_varying_l2``: a log-binned histogram of the FIL
      values, one curve per L2 setting.
    * ``iwpc_mse_vs_eta_varying_l2``: mean (with standard deviation as error
      bars) of the collected test metric against ``mean(etas) / sigma``.

    Args:
        results_path: Location handed to ``load_results`` for every file.
        save_path: Directory joined with the figure names to form the
            ``filename`` argument given to ``plotting.line_plot``.
    """
    L2s = ['1e-5', '1e-3', '1e-1', '1']
    # Kept as strings because they double as keys into the inversion results.
    noise_scales = [
        '1e-05', '2e-05', '5e-05', '0.0001', '0.0002',
        '0.0005', '0.001', '0.002', '0.005', '0.01',
        '0.02', '0.05', '0.1', '0.2', '0.5', '1.0',
    ]

    fils = []
    mean_etas = []
    mses = []
    for l2 in L2s:
        etas = load_results(
            results_path, f"iwpc_least_squares_fil_l2_{l2}.json")["etas"]
        fils.append(etas)
        mean_etas.append(np.mean(etas))

        inversion_results = load_results(
            results_path,
            f"iwpc_least_squares_whitebox_private_inversion_l2_{l2}.json")
        # NOTE(review): the value is read from the 'test_acc' key but is
        # plotted below under the label "Test MSE" -- confirm with the
        # script that writes these results.
        mses.append([
            inversion_results[0][noise_scale]['test_acc']
            for noise_scale in noise_scales
        ])

    sigmas = np.array([float(ns) for ns in noise_scales])
    # Broadcast to one row per L2 value and one column per noise scale.
    mean_etas = np.array(mean_etas)[:, None] / sigmas

    # round() rather than int(): int() truncates toward zero, so a log10
    # result such as -2.9999999 would be mislabelled as 10^-2.
    legend = [
        r"$\lambda=10^{%d}$" % round(math.log10(float(l2))) for l2 in L2s
    ]

    # Plot FILs:
    num_bins = 100
    fil_counts = []
    fil_centers = []
    for fil in fils:
        # Log-spaced bins spanning the observed range; the small offset
        # pushes the last edge just past the maximum value.
        lower = math.log10(np.min(fil))
        upper = math.log10(np.max(fil) + 1e-4)
        bins = np.logspace(lower, upper, num_bins + 1)
        counts, edges = np.histogram(fil, bins=bins)
        centers = (edges[:-1] + edges[1:]) / 2
        fil_counts.append(counts)
        fil_centers.append(centers)
    plotting.line_plot(
        np.array(fil_counts),
        np.array(fil_centers),
        xlabel=r"FIL $\eta$ (at $\sigma=1$)",
        ylabel="Number of examples",
        legend=legend,
        marker=None,
        size=(5, 5),
        xlog=True,
        # Use the save_path argument, not a module-level `args` object.
        filename=os.path.join(save_path, "iwpc_fil_counts_varying_l2"),
    )

    # Plot MSEs:
    mses = np.array(mses)  # [L2, noise_scale, trials]
    mse_means = mses.mean(axis=2)
    mse_stds = mses.std(axis=2)
    plotting.line_plot(
        mse_means, mean_etas, legend=legend,
        xlabel=r"Mean $\bar{\eta}$",
        ylabel="Test MSE",
        ylog=True,
        xlog=True,
        size=(5, 5),
        errors=mse_stds,
        filename=os.path.join(save_path, "iwpc_mse_vs_eta_varying_l2"),
    )