def correlations()

in scripts/make_figures.py [0:0]


def correlations(results_path, save_path, prefix):
    """Scatter-plot ``etas`` against training loss and gradient norm.

    Loads ``f"{prefix}_pca20.json"`` via ``load_results`` and reads its
    ``"etas"``, ``"train_losses"`` and ``"train_grad_norms"`` entries. These are
    assumed to be parallel arrays of equal length, since they are indexed with
    the same sample indices. A reproducible random subset of at most 2000
    indices is plotted in two side-by-side panels that share the y axis:
    eta vs. loss (left) and eta vs. gradient norm (right).

    Args:
        results_path: Passed to ``load_results`` together with the JSON file name.
        save_path: Joined with ``f"{prefix}_scatter_alternatives_eta"`` and passed
            to ``plotting.savefig``.
        prefix: Experiment prefix used in both the input and output file names.
    """
    results = load_results(results_path, f"{prefix}_pca20.json")
    etas = np.array(results["etas"])
    losses = np.array(results["train_losses"])
    grad_norms = np.array(results["train_grad_norms"])

    # Subsample to keep the scatter plots readable. Slicing a permutation
    # shorter than n_samples simply keeps every index. A private RandomState
    # seeded with 2000 draws exactly the same permutation that a global
    # ``np.random.seed(2000)`` would. It avoids clobbering the global RNG
    # state for the rest of the script.
    n_samples = 2000
    rng = np.random.RandomState(n_samples)
    samples = rng.permutation(len(etas))[:n_samples]

    # Raw strings: the LaTeX backslashes (\ell, \|, \eta) are not valid Python
    # escape sequences and would trigger warnings (SyntaxWarning on 3.12+).
    # The resulting string values are unchanged.
    alternatives = [
        (r"(a) Loss $\ell({\bf w^*}^\top {\bf x}, y)$", losses),
        (r"(b) Gradient norm $\|\nabla_{\bf w^*} \ell\|_2$", grad_norms),
    ]

    fig, axarr = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
    fig.subplots_adjust(wspace=0.1)
    for ax, (xlabel, values) in zip(axarr, alternatives):
        ax.scatter(values[samples], etas[samples], s=2.5, color=COLOR)
        ax.set_xlabel(xlabel)
    axarr[0].set_ylabel(r"FIL $\eta$")

    plotting.savefig(os.path.join(save_path, f"{prefix}_scatter_alternatives_eta"))
    # Close rather than merely clear the figure. plt.subplots creates a new
    # figure on every call, so clf() alone would let open figures accumulate.
    plt.close(fig)