in 07_training/serverlessml/flowers/utils/plots.py [0:0]
def training_plot(metrics, history, filename):
    """Plot training vs. validation curves, one subplot per metric, and save.

    For every name in ``metrics`` a subplot is drawn with the training series
    (dashed line) and the matching ``'val_' + name`` series (solid line),
    both read from ``history.history``.

    Args:
        metrics: Sized, ordered collection of metric names (e.g.
            ``['loss', 'accuracy']``). Each name and its ``'val_'``-prefixed
            counterpart must be keys of ``history.history``.
        history: Object exposing a ``history`` mapping from metric name to a
            plottable sequence of values (presumably a Keras ``History``
            returned by ``model.fit`` -- confirm against callers).
        filename: Destination of the image. A path starting with ``gs://``
            is written to a temporary local PNG and then copied with
            ``gsutil cp``; anything else is passed straight to ``savefig``.

    Raises:
        KeyError: If a metric or its ``'val_'`` counterpart is missing from
            ``history.history``.
        subprocess.CalledProcessError: If the ``gsutil cp`` upload fails.
    """
    num_metrics = len(metrics)
    # squeeze=False keeps the axes as a 2-D array even for a single metric;
    # with the default, a lone subplot comes back as a bare Axes object and
    # indexing it fails.
    fig, axes = plt.subplots(
        1, num_metrics, figsize=(5 * num_metrics, 5), squeeze=False)

    for axis, metric in zip(axes[0], metrics):
        val_metric = 'val_' + metric
        axis.plot(history.history[metric], ls='dashed')
        axis.set_xlabel("Epochs")
        axis.set_ylabel(metric)
        axis.plot(history.history[val_metric])
        axis.legend([metric, val_metric])

    # NOTE(review): the figure is left open after saving, as before. Callers
    # that invoke this repeatedly in one process may want to close it.
    if filename.startswith('gs://'):
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfilename = os.path.join(tmpdir, "out.png")
            # Save through the figure handle rather than pyplot's implicit
            # "current figure" so the right figure is always written.
            fig.savefig(tmpfilename)
            # Pass an argv list: splitting a formatted string on whitespace
            # would break paths that contain spaces.
            subprocess.check_call(['gsutil', 'cp', tmpfilename, filename])
    else:
        fig.savefig(filename)