in scripts/plot_results.py [0:0]
def plot_table(table, dirname, file_name, save=True, block=False, fontsize=12):
    """Draw one grouped bar chart per dataset and optionally save/show it.

    Each dataset (in sorted order) gets its own panel in a fixed 1x6 row of
    subplots. Inside a panel there is one group of bars per model (sorted by
    name) and one bar per environment, with the ``'std'`` value drawn as an
    error bar on top of the ``'mean'`` value.

    Args:
        table: Nested mapping indexed as
            ``table[dataset][model][env]['mean' | 'std']``. The environments
            are taken from the first model of each dataset.
        dirname: Currently unused; kept so existing callers keep working.
        file_name: Prefix of the output file name. The figure is written to
            ``figs/<file_name>_<models>.pdf`` where ``<models>`` is the
            sorted model names of the last dataset joined by ``_``.
        save: When true, create ``figs/`` if needed and save the PDF there.
        block: When true, show the figure, wait for the user to press
            Enter, then close all figures.
        fontsize: Currently unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``table`` is empty (there would be nothing to plot
            and no model names to build the file name from).
    """
    # Fail early, before any figure is created, instead of crashing later on
    # variables that are only bound inside the dataset loop.
    if not table:
        raise ValueError("plot_table() requires at least one dataset in `table`")

    # NOTE: the layout is fixed at six panels; the legend anchor further down
    # is hand-tuned for this grid.
    fig, axs = plt.subplots(1, 6, figsize=(13, 2.1))
    axs = axs.flat
    width = None
    for id_d, dataset in enumerate(sorted(table)):
        models = table[dataset]
        labels = sorted(models)
        envs = list(models[list(models)[0]])
        if width is None:
            # Bar width is derived once, from the first dataset, leaving one
            # bar-width of spacing between neighbouring model groups.
            width = 1 / (len(envs) + 1)
        pos = np.arange(len(labels))
        ax = axs[id_d]

        # Only the handles of the last dataset survive the loop; they feed
        # the single shared legend below.
        legends = []
        for id_e, env in enumerate(envs):
            model_means = [models[model][env]['mean'] for model in labels]
            model_stds = [models[model][env]['std'] for model in labels]
            bars = ax.bar(
                pos + id_e * width,
                model_means,
                width=width,
                color=f'C{id_e}',
                label=f'E{env}',
                align='center',
                ecolor='black',
                capsize=3,
                yerr=model_stds,
            )
            legends.append(bars)

        ax.set_title(dataset)
        # Centre each tick under its group of bars.
        ax.set_xticks(pos + width * (len(envs) / 2 - 0.5))
        ax.set_xticklabels(labels, fontsize=7)
        ax.set_ylim(bottom=0)

    axs[0].set_ylabel('Test error')
    plt.tight_layout(pad=0)
    plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plt.legend(
        handles=legends,
        ncol=6,
        loc="lower center",
        bbox_to_anchor=(-2.8, -0.4),
    )

    if save:
        fig_dirname = "figs/"
        os.makedirs(fig_dirname, exist_ok=True)
        # `labels` holds the sorted model names of the last dataset plotted.
        model_tag = '_'.join(labels)
        plt.savefig(
            fig_dirname + file_name + '_' + model_tag + '.pdf',
            format='pdf',
            bbox_inches='tight',
        )

    if block:
        # Non-blocking show followed by a prompt keeps the window open until
        # the user confirms in the terminal.
        plt.show(block=False)
        input('Press to close')
        plt.close('all')