def plot_curves()

in pycls/sweep/plotting.py [0:0]


def plot_curves(sweeps, names, metric, n_curves, reverse=False):
    """Plots metric versus epoch for up to best n_curves jobs per sweep.

    The figure is a grid with one column per sweep and one row per job. Each
    sweep is sorted with ``sort_sweep(sweep, "error", reverse)`` and its first
    ``min(n_curves, len(sweep))`` jobs are drawn. Every cell shows one job's
    train (dashed), test (dotted) and EMA-test (solid) curves, a black
    horizontal line at the job's best EMA value, small dots on EMA points
    within 1% (relative) of that best value, and a larger marker at the best
    EMA point. The y-axis is log2-scaled with ticks from 1 to 100 and capped at
    100, so ``metric`` is assumed to be a percentage-like quantity.

    Args:
        sweeps: list of sweeps; each is passed to ``sort_sweep``/``get_vals``.
        names: one name per sweep, passed to ``fig_legend``.
        metric: stats field read from the "train_epoch", "test_epoch" and
            "test_ema_epoch" logs of every job.
        n_curves: maximum number of jobs (rows) plotted per sweep.
        reverse: forwarded to ``sort_sweep`` to flip the sort order.

    Returns:
        The figure created by ``fig_make``.
    """
    ms = [min(n_curves, len(sweep)) for sweep in sweeps]
    m, n = max(ms), len(sweeps)
    fig, axes = fig_make(m, n, False, sharex=False, sharey=True)
    sweeps = [sort_sweep(sweep, "error", reverse)[0] for sweep in sweeps]
    # get_vals appears to return one sequence of logged values per job, so
    # e.g. ys_ema[j][i] is the EMA-test curve of job i in sweep j.
    xs_trn = [get_vals(sweep, "train_epoch.epoch_ind") for sweep in sweeps]
    xs_tst = [get_vals(sweep, "test_epoch.epoch_ind") for sweep in sweeps]
    xs_ema = [get_vals(sweep, "test_ema_epoch.epoch_ind") for sweep in sweeps]
    xs_max = [get_vals(sweep, "test_ema_epoch.epoch_max") for sweep in sweeps]
    ys_trn = [get_vals(sweep, "train_epoch." + metric) for sweep in sweeps]
    ys_tst = [get_vals(sweep, "test_epoch." + metric) for sweep in sweeps]
    ys_ema = [get_vals(sweep, "test_ema_epoch." + metric) for sweep in sweeps]
    ticks = [1, 2, 4, 8, 16, 32, 64, 100]
    # Lowest test / EMA-test value over all jobs of all sweeps ...
    y_min = min(min(vals) for ys in ys_ema + ys_tst for vals in ys)
    # ... snapped down to the nearest tick so the shared axis starts on a
    # labeled gridline. When every tick exceeds y_min, fall back to the
    # smallest tick instead of wrapping to ticks[-1] (which would collapse
    # the y-range to [100, 100]).
    y_min = max((t for t in ticks if t <= y_min), default=ticks[0])
    for i, j in [(i, j) for j in range(n) for i in range(ms[j])]:
        ax, x_max = axes[i][j], xs_max[j][i][-1]
        x_trn, y_trn, e_trn = xs_trn[j][i], ys_trn[j][i], min(ys_trn[j][i])
        x_tst, y_tst, e_tst = xs_tst[j][i], ys_tst[j][i], min(ys_tst[j][i])
        x_ema, y_ema, e_ema = xs_ema[j][i], ys_ema[j][i], min(ys_ema[j][i])
        label, prop = "{} {:5.2f}", {"color": get_color(j), "alpha": 0.8}
        ax.plot(x_trn, y_trn, "--", **prop, label=label.format("trn", e_trn))
        ax.plot(x_tst, y_tst, ":", **prop, label=label.format("tst", e_tst))
        ax.plot(x_ema, y_ema, "-", **prop, label=label.format("ema", e_ema))
        # Black reference line at the best EMA value across the EMA x-range.
        ax.plot([x_ema[0], x_ema[-1]], [e_ema, e_ema], "-", color="k", alpha=0.8)
        # Dot every EMA point within 1% (relative) of the best value. The list
        # is empty when e_ema <= 0, and scatter() would then get no x/y.
        xy_good = [(x, y) for x, y in zip(x_ema, y_ema) if y < 1.01 * e_ema]
        if xy_good:
            ax.scatter(*zip(*xy_good), **prop, s=10)
        # Place the best-EMA marker at the logged epoch index of the minimum
        # (not its list position) so it lines up with the plotted EMA curve.
        ax.scatter([x_ema[int(np.argmin(y_ema))]], [e_ema], **prop)
        ax.legend(loc="upper right")
        ax.set_xlim(right=x_max)
    # Style every cell of the grid, including rows left empty for sweeps with
    # fewer than m plotted jobs.
    for i, j in [(i, j) for i in range(m) for j in range(n)]:
        ax = axes[i][j]
        ax.set_xlabel("epoch" if i == m - 1 else "")
        ax.set_ylabel(metric if j == 0 else "")
        ax.set_yscale("log", base=2)
        ax.set_yticks(ticks)
        ax.set_yticklabels(ticks)
        # Unlabeled minor ticks halfway (in log2 space) between major ticks.
        ax.set_yticks([t * np.sqrt(2) for t in ticks], minor=True)
        ax.set_yticklabels([], minor=True)
        ax.set_ylim(bottom=y_min, top=100)
        ax.yaxis.grid(True, which="minor")
    fig_legend(fig, n, names, styles="-", markers="")
    return fig