def training_plot()

in 07_training/serverlessml/flowers/utils/plots.py [0:0]


def training_plot(metrics, history, filename):
    """Plot training vs. validation curves for each metric and save the image.

    One subplot per entry in ``metrics`` is laid out side by side. Each subplot
    draws ``history.history[metric]`` (dashed) and
    ``history.history['val_' + metric]`` (solid) against the epoch index.

    Args:
        metrics: non-empty sequence of metric names. Each name and its
            ``'val_'``-prefixed counterpart must be keys of ``history.history``.
        history: object whose ``history`` attribute maps metric names to
            per-epoch values (presumably a Keras ``History`` -- TODO confirm).
        filename: destination path for the image. Paths starting with
            ``gs://`` are saved to a local temporary file first and then
            copied with ``gsutil cp``; any other path is written directly.

    Raises:
        ValueError: if ``metrics`` is empty.
        KeyError: if a metric or its ``'val_'`` counterpart is missing from
            ``history.history`` (assuming it is a dict).
        subprocess.CalledProcessError: if the ``gsutil cp`` upload fails.
    """
    if len(metrics) == 0:
        raise ValueError("metrics must contain at least one metric name")

    # squeeze=False keeps `axes` a 2-D array even for a single metric; with the
    # default squeeze=True, plt.subplots(1, 1) returns a bare Axes object and
    # indexing it raises TypeError.
    fig, axes = plt.subplots(1, len(metrics), figsize=(5 * len(metrics), 5),
                             squeeze=False)
    for ax, metric in zip(axes[0], metrics):
        val_metric = 'val_' + metric
        # Plot order determines legend order: training first, then validation.
        ax.plot(history.history[metric], ls='dashed')
        ax.plot(history.history[val_metric])
        ax.set_xlabel("Epochs")
        ax.set_ylabel(metric)
        ax.legend([metric, val_metric])

    if filename.startswith('gs://'):
        # savefig is given a local path here, so write to a temp file and
        # upload it with gsutil.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfilename = os.path.join(tmpdir, "out.png")
            fig.savefig(tmpfilename)
            # An explicit argv list (no str.split) keeps paths that contain
            # spaces intact.
            subprocess.check_call(['gsutil', 'cp', tmpfilename, filename])
    else:
        fig.savefig(filename)