in reproduce/plot_test_error.py [0:0]
def _load_run_errors(fname, loss_key):
    """Load one pickled run file and return its per-epoch error curve.

    The unpickled object is indexed with ``loss_key`` and every stored value
    ``v`` is converted to ``100.0 - v``.
    NOTE(review): that conversion presumes the stored values are accuracies in
    percent despite the key name -- confirm against the training code.

    Args:
        fname: Path of the pickle file to read.
        loss_key: Key of the per-epoch values inside the unpickled mapping.

    Returns:
        A new list with one converted value per stored value.
    """
    # pickle.load can execute arbitrary code: only point this at run files
    # you produced yourself, never at downloaded / untrusted data.
    with open(fname, 'rb') as fdata:
        run_data = pickle.load(fdata)
    return [100.0 - val for val in run_data[loss_key]]


def _average_runs(curves):
    """Return the epoch-wise mean of several per-run curves.

    Curves of different lengths are averaged over their common prefix (the
    length of the shortest curve), so a run that stopped early neither raises
    an IndexError nor silently drops epochs of another run.

    Args:
        curves: Non-empty list of numeric sequences.

    Returns:
        List of means with ``len == min(len(c) for c in curves)``.
    """
    n_runs = len(curves)
    return [sum(column) / n_runs for column in zip(*curves)]


def plot_averaged(plot_name, plot_entries, xvals, yrange=None):
    """Plot the run-averaged test-error curve of each entry and save a PDF.

    For every entry, all files matched by the glob pattern ``entry["fname"]``
    are loaded, converted to error curves and averaged epoch-wise.  The mean
    curve is drawn with label ``entry["label"]``.  Lines are labelled in place
    via ``labelLines`` (no legend), and the figure is written to
    ``plots/<plot_name>.pdf``.

    Args:
        plot_name: Base name of the output PDF (without extension).
        plot_entries: Iterable of dicts with keys ``"fname"`` (glob pattern of
            pickled run files) and ``"label"`` (curve label).
        xvals: Forwarded to ``labelLines``; the x positions (epoch numbers)
            at which the line labels are placed.
        yrange: Optional y-axis limits; defaults to ``[0, largest error seen
            in any single run]``.

    Raises:
        ValueError: If ``plot_entries`` is empty.
        FileNotFoundError: If an entry's pattern matches no file.  This is a
            subclass of Exception, so existing ``except Exception`` callers
            still catch it.
    """
    plot_dir = "plots"
    loss_key = "test_errors"
    ylabel = "Test error (%)"

    plot_entries = list(plot_entries)
    if not plot_entries:
        raise ValueError("plot_entries is empty; nothing to plot")

    os.makedirs(plot_dir, exist_ok=True)

    # Phase 1: load and average all data before creating any figure, so a
    # missing file cannot leave a half-built figure open.
    curves = []  # (label, averaged curve) per entry, in input order
    max_seen = 0  # largest single-run error; default upper y limit
    for plot_entry in plot_entries:
        label = plot_entry["label"]
        pattern = plot_entry["fname"]
        # Sorted so runs are processed (and reported) in a stable order.
        data_files = sorted(glob.glob(pattern))
        if not data_files:
            raise FileNotFoundError(
                "No files found matching path: {}".format(pattern))

        errors_lists = []
        for fname in data_files:
            print("(ALL) processing run ", fname)
            errors = _load_run_errors(fname, loss_key)
            print("Final test error {} for {}".format(errors[-1], label))
            errors_lists.append(errors)
            max_seen = max(max_seen, max(errors))

        errors_avg = _average_runs(errors_lists)
        curves.append((label, errors_avg))
        print("Average final test error {} for {}".format(errors_avg[-1], label))

    # x limit covers the longest plotted (averaged) curve.
    max_epoch = max(len(errors_avg) for _, errors_avg in curves)

    # Phase 2: draw.  A fresh figure per call, always closed afterwards so
    # repeated calls do not accumulate open figures.  (No plt.cla(): with no
    # current figure it would create an extra, never-closed empty one.)
    fig = plt.figure(figsize=(3.3, 2))
    try:
        ax = fig.add_subplot(111)
        ax.set_prop_cycle("color", colors)
        for line_idx, (label, errors_avg) in enumerate(curves):
            ax.plot(
                range(len(errors_avg)),
                errors_avg,
                label=label,
                # Cycle through the styles instead of failing with IndexError
                # when there are more entries than styles.
                linestyle=linestyles[line_idx % len(linestyles)])

        print("Finalizing plot")
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_xlim([0, max_epoch])
        ax.set_ylim(yrange if yrange is not None else [0, max_seen])
        ax.grid(False)
        ax.xaxis.set_tick_params(direction='in')
        ax.yaxis.set_tick_params(direction='in', right=True)
        labelLines(ax.get_lines(), align=False, fontsize=label_fontsize, xvals=xvals)
        figname = "{}/{}.pdf".format(plot_dir, plot_name)
        fig.savefig(figname, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    print("saved", figname)