in scripts/plot_results.py [0:0]
def plot_table_avg(table, dirname, file_name, save=True, block=False, fontsize=12):
    """Plot per-dataset bar charts of model means with std error bars.

    Each dataset in ``table["data"]`` (visited in sorted key order) fills one
    panel of a fixed 1x6 grid. Inside a panel there is one bar per model
    (models in sorted name order); the bar height is the model's ``'mean'``
    entry and the error bar its ``'std'`` entry. A single legend naming the
    models of the last plotted dataset is placed below the panels.

    Args:
        table: Mapping whose ``"data"`` entry has the form
            ``{dataset: {model: {'mean': ..., 'std': ...}}}``.
        dirname: Currently unused; figures are always written to ``figs/``.
        file_name: Prefix of the output file name.
        save: If true, write the figure to
            ``figs/<file_name>_<models>.pdf``, where ``<models>`` is the
            ``'_'``-joined sorted model names of the last dataset.
        block: If true, show the figure without blocking and then wait for
            a line on stdin before closing it.
        fontsize: Currently unused.

    Raises:
        ValueError: If ``table["data"]`` contains no datasets.
    """
    table = table["data"]
    datasets = sorted(table.keys())
    if not datasets:
        # Fail before a figure is created; otherwise the legend/save code
        # below would hit undefined names and leave a dangling figure.
        raise ValueError("table['data'] contains no datasets to plot")

    _, axs = plt.subplots(1, 6, figsize=(13, 2.1))
    axs = axs.flat
    width = 0.5

    for id_d, dataset in enumerate(datasets):
        models = table[dataset]
        labels = sorted(models.keys())
        pos = np.arange(len(labels))
        ax = axs[id_d]

        # Rebuilt for every dataset: the legend and the output file name are
        # derived from the last dataset that was plotted.
        handles = []
        for id_m, label in enumerate(labels):
            stats = models[label]
            handles.append(
                ax.bar(pos[id_m], stats['mean'],
                       width=width, color=f'C{id_m}',
                       align='center', ecolor='black',
                       capsize=7, yerr=stats['std'], linewidth=0.1)
            )

        ax.set_title(dataset)
        ax.set_xticks(pos)
        ax.set_ylim(bottom=0)

    axs[0].set_ylabel('Test error')
    plt.tight_layout(pad=0)
    plt.subplots_adjust(wspace=0.3, hspace=0.3)
    # Pass the bar handles explicitly. With labels only, the legend is built
    # from the artists of the current axes, which holds no bars when fewer
    # datasets than panels were plotted.
    plt.legend(handles, labels,
               ncol=6,
               loc="lower center",
               bbox_to_anchor=(-2.8, -0.5))

    if save:
        fig_dirname = "figs/"
        os.makedirs(fig_dirname, exist_ok=True)
        model_suffix = '_'.join(labels)
        plt.savefig(f"{fig_dirname}{file_name}_{model_suffix}.pdf",
                    format='pdf', bbox_inches='tight')

    if block:
        # Non-blocking show; the prompt below is what keeps the window open.
        plt.show(block=False)
        input('Press to close')

    plt.close('all')