07_training/serverlessml/flowers/utils/plots.py (20 lines of code) (raw):

#!/usr/bin/env python

# Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import matplotlib.pylab as plt
import numpy as np
import os
import shutil
import tempfile
import subprocess


def training_plot(metrics, history, filename):
    """Plot training and validation curves for each metric and save the figure.

    One panel is drawn per entry of ``metrics``, side by side. In each panel
    the training series is drawn dashed and the validation series (looked up
    under ``'val_' + metric``) is drawn solid.

    Args:
        metrics: Sized sequence of metric names, e.g. ``['loss', 'accuracy']``.
        history: Object exposing a ``history`` mapping from metric name to a
            sequence of per-epoch values (presumably a Keras ``History``;
            confirm against the caller). Both ``metric`` and
            ``'val_' + metric`` must be present for every requested metric.
        filename: Destination path. If it starts with ``gs://`` the image is
            first written to a temporary local PNG and then copied with
            ``gsutil cp``; otherwise it is written directly to ``filename``.

    Raises:
        ValueError: If ``metrics`` is empty.
        subprocess.CalledProcessError: If ``gsutil cp`` exits with a non-zero
            status.
    """
    if len(metrics) == 0:
        raise ValueError("metrics must contain at least one metric name")

    # squeeze=False keeps `ax` a 2-D array of shape (1, len(metrics)) even
    # when a single metric is requested; with the default squeeze=True a lone
    # Axes object is returned and indexing it fails.
    fig, ax = plt.subplots(
        1, len(metrics), figsize=(5 * len(metrics), 5), squeeze=False
    )
    for idx, metric in enumerate(metrics):
        val_metric = 'val_' + metric
        panel = ax[0, idx]
        panel.plot(history.history[metric], ls='dashed')
        panel.set_xlabel("Epochs")
        panel.set_ylabel(metric)
        panel.plot(history.history[val_metric])
        panel.legend([metric, val_metric])

    on_cloud = filename.startswith('gs://')
    if on_cloud:
        # matplotlib cannot write to Cloud Storage directly, so render to a
        # local temporary file and upload it with gsutil.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfilename = os.path.join(tmpdir, "out.png")
            fig.savefig(tmpfilename)
            # Pass arguments as a list so paths containing whitespace are not
            # split into several arguments.
            subprocess.check_call(['gsutil', 'cp', tmpfilename, filename])
    else:
        fig.savefig(filename)