in visualization/plotting.py [0:0]
def plot_itrs(save_fname='itr.pdf', val=False):
    """Plot error-versus-time curves for the 8- and 16-node runs.

    The experiment layout comes from ``get_eth_config()``. For every node
    count equal to 8 or 16, one curve per algorithm is drawn on a shared
    axis: 8-node curves are dashed, 16-node curves are solid. The figure is
    written to ``save_fname`` and then shown interactively.

    Args:
        save_fname: Path passed to ``fig.savefig``.
        val: If true, plot the ``val_mean`` column and label the y-axis as
            validation error; otherwise plot ``train_mean`` and label it as
            training error.
    """
    nodes, fpaths, tags, legends, colors = get_eth_config()
    fig = plt.figure(1)
    ax = fig.add_subplot(111)

    # Choose the plotted column and axis label once so both modes share a
    # single plotting call with identical styling.
    y_column = 'val_mean' if val else 'train_mean'
    y_label = 'Validation Error (%)' if val else 'Training Error (%)'

    # tags/legends/colors are transposed with zip(*...) so that each step of
    # the outer loop yields the per-algorithm entries for one node count.
    for n, ntags, nlegends, ncolors in zip(nodes, zip(*tags), zip(*legends),
                                           zip(*colors)):
        if n not in (8, 16):
            continue
        # The line style encodes the node count; it is the same for every
        # algorithm at this node count, so compute it once per node.
        linestyle = '--' if n == 8 else '-'
        # fpaths is indexed by algorithm, in step with the per-node tuples.
        for tag, label, color, fpath in zip(ntags, nlegends, ncolors,
                                            fpaths):
            df = parse_csv(n, tag, fpath)
            # Always forward linestyle: previously the validation plot
            # dropped it, making 8- and 16-node curves indistinguishable.
            df.plot(x='time', y=y_column, ax=ax, color=color,
                    grid=True, label=label, fontsize=16,
                    linestyle=linestyle)

    ax.set_ylabel(y_label, fontsize=font_size)
    ax.set_xlabel('Time (s)', fontsize=font_size)
    # Draw both grids, with the minor grid fainter than the major one.
    ax.grid(which='both')
    ax.grid(which='minor', alpha=0.2)
    ax.grid(which='major', alpha=0.5)
    plt.legend(prop={'size': 16})
    plt.tight_layout()
    fig.savefig(save_fname)
    plt.show()