def plot_models()

in pycls/sweep/plotting.py [0:0]


def plot_models(sweeps, names, n_models, reverse=False):
    """Plots model visualization for up to n_models per sweep.

    Builds a grid of subplots with one column per sweep. Column j holds the
    first min(n_models, len(sweeps[j])) models of sweep j after sorting it by
    "error" with sort_sweep. Each subplot plots per-block width against block
    index, circles the first block of every stage, and attaches a text legend
    listing the model's metrics and architecture parameters.

    Args:
        sweeps: list of sweeps, each a sequence of model records that
            sort_sweep / get_vals understand.
        names: display names for the sweeps, passed to fig_legend.
        n_models: maximum number of models (rows) to draw per sweep.
        reverse: forwarded to sort_sweep to control the sort order.

    Returns:
        The figure created by fig_make.

    Raises:
        ValueError: if sweeps is empty.
        AssertionError: if a model's cfg.MODEL.TYPE is not "regnet",
            "anynet" or "effnet".
    """
    if not sweeps:
        raise ValueError("plot_models requires at least one sweep")
    rows_per_col = [min(n_models, len(sweep)) for sweep in sweeps]
    n_rows, n_cols = max(rows_per_col), len(sweeps)
    fig, axes = fig_make(n_rows, n_cols, False, sharex=True, sharey=True)
    sweeps = [sort_sweep(sweep, "error", reverse)[0] for sweep in sweeps]
    base_metrics = ["error", "flops", "params", "acts", "epoch_fw_bw", "resolution"]
    marker_size = plt.rcParams["lines.markersize"] - 1
    for j in range(n_cols):
        color = get_color(j)
        for i in range(rows_per_col[j]):
            # get_vals operates on a sweep, so wrap the single model in a list.
            ax, sweep = axes[i][j], [sweeps[j][i]]
            vals = [get_vals(sweep, metric)[0] for metric in base_metrics]
            label = "e = {:.2f}%, f = {:.2f}B\n".format(*vals[0:2])
            label += "p = {:.2f}M, a = {:.2f}M\n".format(*vals[2:4])
            label += "t = {0:.0f}s, r = ${1:d} \\times {1:d}$\n".format(*vals[4:6])
            model_type = get_vals(sweep, "cfg.MODEL.TYPE")[0]
            if model_type == "regnet":
                keys = ["GROUP_W", "BOT_MUL", "WA", "W0", "WM", "DEPTH"]
                params = [get_vals(sweep, "cfg.REGNET." + key)[0] for key in keys]
                # params[2:] are (WA, W0, WM, DEPTH), in that order.
                ws, ds, _, _, _, ws_cont = regnet.generate_regnet(*params[2:])
                label += "$d_i = {:s}$\n$w_i = {:s}$\n".format(str(ds), str(ws))
                label += "$g={:d}$, $b={:g}$, $w_a={:.1f}$\n".format(*params[:3])
                label += "$w_0={:d}$, $w_m={:.3f}$".format(*params[3:5])
                # Dotted line: the continuous widths returned by generate_regnet.
                ax.plot(ws_cont, ":", c=color)
            elif model_type == "anynet":
                keys = ["anynet_ds", "anynet_ws", "anynet_gs", "anynet_bs"]
                ds, ws, gs, bs = [get_vals(sweep, key)[0] for key in keys]
                label += "$d_i = {:s}$\n$w_i = {:s}$\n".format(str(ds), str(ws))
                label += "$g_i = {:s}$\n$b_i = {:s}$".format(str(gs), str(bs))
            elif model_type == "effnet":
                keys = ["effnet_ds", "effnet_ws", "effnet_ss", "effnet_bs"]
                ds, ws, ss, bs = [get_vals(sweep, key)[0] for key in keys]
                label += "$d_i = {:s}$\n$w_i = {:s}$\n".format(str(ds), str(ws))
                label += "$s_i = {:s}$\n$b_i = {:s}$".format(str(ss), str(bs))
            else:
                raise AssertionError("Unknown model type: {}".format(model_type))
            # Expand per-stage widths to one width per block.
            ws_all = [w for d, w in zip(ds, ws) for _ in range(d)]
            # Index of each stage's first block = sum of preceding stage depths.
            # list() keeps "+" a concatenation even if ds is a numpy array.
            ds_cum = np.cumsum([0] + list(ds[:-1]))
            ax.plot(ws_all, "o-", c=color, markersize=marker_size)
            ax.plot(ds_cum, ws, "o", c="k", fillstyle="none", label=label)
            # Zero-size legend handles so only the text label is shown.
            ax.legend(loc="lower right", markerscale=0, handletextpad=0, handlelength=0)
    for i in range(n_rows):
        for j in range(n_cols):
            ax = axes[i][j]
            ax.set_xlabel("block index" if i == n_rows - 1 else "")
            ax.set_ylabel("width" if j == 0 else "")
            ax.set_yscale("log", base=2)
            ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
            ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    fig_legend(fig, n_cols, names, styles="-")
    return fig