in pycls/sweep/plotting.py [0:0]
def plot_curves(sweeps, names, metric, n_curves, reverse=False):
    """Plots metric versus epoch for up to best n_curves jobs per sweep.

    Builds a grid of axes where ``axes[i][j]`` holds job ``i`` of sweep ``j``.
    Jobs in each sweep are ordered with ``sort_sweep(sweep, "error", reverse)``
    and only the first ``n_curves`` jobs of each sweep are drawn. Each axis
    shows the train (dashed), test (dotted) and EMA test (solid) curves of one
    job, a black horizontal line at the lowest EMA value, small dots on EMA
    points within 1% of that value, and a larger marker at the best EMA epoch.
    The y axis is log base 2 with fixed ticks [1, 2, ..., 64, 100] and top 100.

    Args:
        sweeps: list of sweeps; each is passed to ``sort_sweep``/``get_vals``.
        names: one name per sweep, forwarded to ``fig_legend``.
        metric: stat name read from the ``train_epoch.``, ``test_epoch.`` and
            ``test_ema_epoch.`` records (e.g. ``"top1_err"`` -- TODO confirm).
        n_curves: maximum number of jobs (rows) plotted per sweep.
        reverse: forwarded to ``sort_sweep`` to flip the job ordering.

    Returns:
        The figure created by ``fig_make``.
    """
    ms = [min(n_curves, len(sweep)) for sweep in sweeps]
    m, n = max(ms), len(sweeps)
    fig, axes = fig_make(m, n, False, sharex=False, sharey=True)
    sweeps = [sort_sweep(sweep, "error", reverse)[0] for sweep in sweeps]
    xs_trn = [get_vals(sweep, "train_epoch.epoch_ind") for sweep in sweeps]
    xs_tst = [get_vals(sweep, "test_epoch.epoch_ind") for sweep in sweeps]
    xs_ema = [get_vals(sweep, "test_ema_epoch.epoch_ind") for sweep in sweeps]
    xs_max = [get_vals(sweep, "test_ema_epoch.epoch_max") for sweep in sweeps]
    ys_trn = [get_vals(sweep, "train_epoch." + metric) for sweep in sweeps]
    ys_tst = [get_vals(sweep, "test_epoch." + metric) for sweep in sweeps]
    ys_ema = [get_vals(sweep, "test_ema_epoch." + metric) for sweep in sweeps]
    ticks = [1, 2, 4, 8, 16, 32, 64, 100]
    # Lowest test/EMA value over every job of every sweep. Train values are not
    # included, so train curves may extend below the shared y-axis bottom.
    y_min = min(min(ys) for sweep_ys in ys_ema + ys_tst for ys in sweep_ys)
    # Snap the y-axis bottom down to the largest tick <= y_min. When y_min is
    # below every tick, keep y_min itself (an argmin-based lookup would wrap
    # to ticks[-1] == 100 here and collapse the axis to [100, 100]).
    y_min = max((t for t in ticks if t <= y_min), default=y_min)
    for j in range(n):
        for i in range(ms[j]):
            ax, x_max = axes[i][j], xs_max[j][i][-1]
            x_trn, y_trn, e_trn = xs_trn[j][i], ys_trn[j][i], min(ys_trn[j][i])
            x_tst, y_tst, e_tst = xs_tst[j][i], ys_tst[j][i], min(ys_tst[j][i])
            x_ema, y_ema, e_ema = xs_ema[j][i], ys_ema[j][i], min(ys_ema[j][i])
            label, prop = "{} {:5.2f}", {"color": get_color(j), "alpha": 0.8}
            ax.plot(x_trn, y_trn, "--", **prop, label=label.format("trn", e_trn))
            ax.plot(x_tst, y_tst, ":", **prop, label=label.format("tst", e_tst))
            ax.plot(x_ema, y_ema, "-", **prop, label=label.format("ema", e_ema))
            ax.plot([x_ema[0], x_ema[-1]], [e_ema, e_ema], "-", color="k", alpha=0.8)
            # Small dots on EMA points within 1% of the best EMA value. The
            # selection can be empty (e.g. e_ema == 0) and scatter() requires
            # x and y, so only draw when something was selected.
            xy_good = [(x, y) for x, y in zip(x_ema, y_ema) if y < 1.01 * e_ema]
            if xy_good:
                ax.scatter(*zip(*xy_good), **prop, s=10)
            # Mark the best epoch at its actual epoch value; epoch indices
            # need not be the consecutive sequence 1..N.
            x_best = x_ema[int(np.argmin(y_ema))]
            ax.scatter([x_best], [e_ema], **prop)
            ax.legend(loc="upper right")
            ax.set_xlim(right=x_max)
    for i in range(m):
        for j in range(n):
            ax = axes[i][j]
            ax.set_xlabel("epoch" if i == m - 1 else "")
            ax.set_ylabel(metric if j == 0 else "")
            ax.set_yscale("log", base=2)
            ax.set_yticks(ticks)
            ax.set_yticklabels(ticks)
            # Unlabeled minor ticks halfway (in log2 space) between major ticks
            # provide the grid lines.
            ax.set_yticks([t * np.sqrt(2) for t in ticks], minor=True)
            ax.set_yticklabels([], minor=True)
            ax.set_ylim(bottom=y_min, top=100)
            ax.yaxis.grid(True, which="minor")
    fig_legend(fig, n, names, styles="-", markers="")
    return fig