def plot_itrs()

in visualization/plotting.py [0:0]


def plot_itrs(save_fname='itr.pdf', val=False, plot_nodes=(8, 16),
              dashed_nodes=(8,)):
    """Plot training (or validation) error against wall-clock time.

    One curve is drawn per (node count, algorithm) pair obtained from
    ``get_eth_config()``. Only node counts in ``plot_nodes`` are drawn, and
    those in ``dashed_nodes`` use a dashed line so the cluster sizes can be
    told apart. The figure is saved to ``save_fname`` and then shown.

    Args:
        save_fname: Output path passed to ``Figure.savefig``.
        val: If True plot the ``val_mean`` column, otherwise ``train_mean``.
        plot_nodes: Node counts to include. The default keeps the original
            8- and 16-node selection.
        dashed_nodes: Node counts drawn with a dashed line. All other
            plotted node counts use a solid line.
    """
    nodes, fpaths, tags, legends, colors = get_eth_config()
    y_col = 'val_mean' if val else 'train_mean'
    ylabel = 'Validation Error (%)' if val else 'Training Error (%)'

    # Always start a fresh figure: plt.figure(1) would return (and draw on
    # top of) a figure 1 left over from a previous call.
    fig, ax = plt.subplots()

    # tags/legends/colors are transposed with zip(*...), so they appear to be
    # indexed [alg][node]; each resulting tuple lists one node's values per
    # algorithm. NOTE(review): confirm against get_eth_config().
    for n, ntags, nlegends, ncolors in zip(nodes, zip(*tags), zip(*legends),
                                           zip(*colors)):
        if n not in plot_nodes:
            continue
        linestyle = '--' if n in dashed_nodes else '-'
        for i, (tag, label, color) in enumerate(zip(ntags, nlegends,
                                                    ncolors)):
            df = parse_csv(n, tag, fpaths[i])
            # linestyle is applied to validation curves as well; previously
            # it was computed but dropped when val=True.
            df.plot(x='time', y=y_col, ax=ax, color=color, grid=True,
                    label=label, fontsize=16, linestyle=linestyle)

    ax.set_ylabel(ylabel, fontsize=font_size)
    ax.set_xlabel('Time (s)', fontsize=font_size)
    # Passing style kwargs also switches the respective grid on.
    ax.grid(which='minor', alpha=0.2)
    ax.grid(which='major', alpha=0.5)
    ax.legend(prop={'size': 16})
    fig.tight_layout()
    fig.savefig(save_fname)
    plt.show()