graph.py (221 lines of code) (raw):
import argparse
import subprocess
import shlex
import os
import glob
import json
import time
from matplotlib import pyplot as plt
import numpy as np
# Local cache for downloaded logs and output directory for rendered graphs.
LOGDIR = os.path.expanduser('~/bigtrans_logs')
GRAPHDIR = os.path.expanduser('~/bigtrans_graphs')
# Remote prefix that log paths are listed/copied from with gsutil.
# NOTE(review): this looks like a placeholder value that must be filled in.
BUCKET = '<input bucket>'
os.makedirs(LOGDIR, exist_ok=True)
os.makedirs(GRAPHDIR, exist_ok=True)

# x-axis variable names whose values are already expressed in epochs.
EPOCH_VARS = frozenset({'epoch', 'n_epochs'})

# Used for logs whose header line contains a 'code' key:
# command-line x variable name -> key in the log records.
x_var_mapping = dict(
    epoch='n_epochs',
    step='n_updates',
)

# Used for logs whose header line contains a 'code' key:
# series_id -> (command-line y variable name -> key in the log records).
y_var_mapping = dict(
    eval_loss=dict(
        loss='valid_gen_loss',
        loss_clf='valid_clf_loss',
        acc_clf='valid_acc',
    ),
    train_loss=dict(
        loss_avg='train_gen_loss',
        loss_clf_avg='train_clf_loss',
    ),
)
class Series(object):
    """A single (x, y) curve extracted from one JSONL training log.

    The first line of the log is parsed as a header ("identifier"). Every
    following line that parses as a JSON object is a candidate data record.

    Two log layouts are handled:

    * If the header contains a ``'code'`` key, ``x_var`` / ``y_var`` are
      translated through ``x_var_mapping`` and ``y_var_mapping[series_id]``,
      and every record is considered.
    * Otherwise only records whose ``'series'`` field equals ``series_id``
      are considered.

    Records that lack either variable are dropped.

    Attributes:
        name: ``model_name``, suffixed with ``":" + legend`` when a legend is
            given.
        x, y: float64 numpy arrays holding the curve.
        xmax, ymax: x position and value of the largest y (``None`` if the
            curve is empty).
        xmin, ymin: x position and value of the smallest y (``None`` if the
            curve is empty).
    """

    def __init__(self, logpath, model_name, series_id, x_var, y_var, average,
                 base=None, convert_to_epochs=False, legend=None):
        """Load and post-process one curve.

        Args:
            logpath: path of the JSONL log file.
            model_name: label used in the plot legend.
            series_id: which series to extract (e.g. 'eval_loss').
            x_var, y_var: record keys (or mapping keys, for 'code' logs) used
                for the x and y values.
            average: if truthy, the size of a trailing moving-average window
                applied to y.
            base: if given, y is divided by ln(base) (presumably to convert
                losses from nats to another log base; confirm for the log).
            convert_to_epochs: divide x by the log's ``n_updates_per_epoch``
                unless ``x_var`` is already an epoch variable.
            legend: optional suffix appended to ``name``.

        Raises:
            ValueError: if epoch conversion is requested for a non-empty curve
                but the log never reports ``n_updates_per_epoch``.
        """
        self.name = model_name
        if legend:
            self.name += ":" + legend
        with open(logpath, 'r') as f:
            lines = f.readlines()
        # An empty file yields an empty curve instead of an IndexError.
        identifier = json.loads(lines[0]) if lines else {}
        img_gen_repr_learn = 'code' in identifier
        if img_gen_repr_learn:
            x_var = x_var_mapping[x_var]
            y_var = y_var_mapping[series_id][y_var]

        data = []
        epoch_length = None
        for line in lines[1:]:
            try:
                parse = json.loads(line)
            except json.JSONDecodeError:
                # Skip lines that are not valid JSON.
                continue
            if not isinstance(parse, dict):
                # Bare numbers/strings/lists cannot be records.
                continue
            if epoch_length is None and 'n_updates_per_epoch' in parse:
                epoch_length = float(parse['n_updates_per_epoch'])
            if img_gen_repr_learn:
                data.append(parse)
            elif 'series' in parse and parse['series'] == series_id:
                data.append(parse)

        data = [d for d in data if x_var in d and y_var in d]
        self.x = np.array([d[x_var] for d in data]).astype(np.float64)
        self.y = np.array([d[y_var] for d in data]).astype(np.float64)

        if convert_to_epochs and x_var not in EPOCH_VARS and len(self.x) > 0:
            if epoch_length is None:
                raise ValueError(
                    f"{logpath}: cannot convert '{x_var}' to epochs; "
                    "no 'n_updates_per_epoch' record found")
            self.x /= epoch_length
        if base is not None:
            self.y /= np.log(base)
        if average:
            # Trailing moving average: point j-1 is the mean of the last
            # `window` values up to and including it (fewer at the start).
            window = int(average)
            self.y = np.array([
                self.y[max(0, j - window):j].mean()
                for j in range(1, len(self.y) + 1)
            ])

        self.xmax = self.ymax = self.xmin = self.ymin = None
        if len(self.x) > 0 and len(self.y) > 0:
            max_idx = np.argmax(self.y)
            min_idx = np.argmin(self.y)
            self.xmax = self.x[max_idx]
            self.ymax = self.y[max_idx]
            self.xmin = self.x[min_idx]
            self.ymin = self.y[min_idx]
def _build_parser():
    """Return the command-line parser for this plotting script."""
    parser = argparse.ArgumentParser()
    # Comma-separated model name substrings. Each one is placed in the
    # `gsutil ls` path and used as a local glob pattern under LOGDIR.
    # Required: the script cannot do anything without it.
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--title', type=str, default=None)
    parser.add_argument('--skip_cp', action="store_true")
    parser.add_argument('--ylim', type=str, default="")
    parser.add_argument('--xlim', type=str, default="")
    # Comma-separated "series_id:x_var:y_var" triples, one curve group each.
    parser.add_argument('--series', type=str, default="eval_loss:epoch:loss")
    parser.add_argument('--average', type=int, default=None)
    parser.add_argument('--train', action="store_true")
    parser.add_argument('--valid', action="store_true")
    parser.add_argument('--acc', action="store_true")
    parser.add_argument('--clf_loss', action="store_true")
    parser.add_argument('--train_valid', action="store_true")
    parser.add_argument('--max', action="store_true")
    parser.add_argument('--base', type=float)
    parser.add_argument('--logy', action="store_true")
    parser.add_argument('--logx', action="store_true")
    parser.add_argument('--show', action='store_true')
    return parser


def _apply_presets(args):
    """Expand the shortcut flags into ``args.series`` / ``args.max``.

    Mutates ``args`` in place. Flags are applied in the order --train,
    --valid, --acc, --clf_loss, --train_valid, so later ones win.

    Returns:
        A list of legend suffixes with one entry per comma-separated item of
        the resulting ``args.series`` (``None`` means "no suffix").
    """
    if args.train:
        args.series = 'train_loss:step:loss_avg'
    if args.valid:
        args.series = 'eval_loss:epoch:loss'
    if args.acc:
        args.series = 'eval_loss:epoch:acc_clf'
        args.max = True
    if args.clf_loss:
        args.series = 'eval_loss:epoch:loss_clf'
    if args.train_valid:
        if args.acc:
            # NOTE(review): 'loss_acc' is not a key of
            # y_var_mapping['train_loss'] (only 'loss_avg'/'loss_clf_avg'
            # are), so for 'code'-style logs the mapping lookup will raise
            # KeyError; confirm the intended train-accuracy key.
            args.series = 'eval_loss:epoch:acc_clf,train_loss:step:loss_acc'
            args.max = True
        elif args.clf_loss:
            args.series = 'eval_loss:epoch:loss_clf,train_loss:step:loss_clf_avg'
            args.max = False
        else:
            args.series = 'eval_loss:epoch:loss,train_loss:step:loss_avg'
        return ["valid", "train"]
    # One empty legend per requested series, so that zipping legends with the
    # series list does not silently drop all but the first series.
    return [None] * len(args.series.split(','))


def _parse_limits(text):
    """Parse a ``"lo,hi"`` string into a ``(lo, hi)`` pair of floats."""
    lo, hi = (float(v) for v in text.split(','))
    return lo, hi


if __name__ == "__main__":
    args = _build_parser().parse_args()
    if not args.title:
        args.title = args.model
    # Basic sanity-checks
    if args.acc and args.base is not None:
        raise ValueError("Converting to other units is supported only for generative losses")
    legends = _apply_presets(args)

    os.makedirs(LOGDIR, exist_ok=True)
    strs = args.model.split(',')
    print('Plotting models with names', strs)

    # Mirror matching remote logs into LOGDIR/<name>/log.jsonl.
    prefix = BUCKET
    suffix = '/log.jsonl'
    if not args.skip_cp:
        files = []
        for s in strs:
            modelstr = f'{prefix}{s}{suffix}'
            try:
                o = subprocess.check_output(['gsutil', 'ls', modelstr])
                files += [a.decode('utf-8') for a in o.split()]
            except subprocess.CalledProcessError:
                print(f'ls failed for {modelstr}')
        # Start all copies concurrently, then wait for each to finish.
        procs = []
        for remote in files:
            name = remote[len(prefix):-len(suffix)]
            local = os.path.join(LOGDIR, name, 'log.jsonl')
            procs.append(subprocess.Popen(['gsutil', 'cp', remote, local]))
        for proc in procs:
            if proc.wait() != 0:
                print('cp failed:', ' '.join(proc.args))

    # Collect (local log path, model name) pairs.
    localpaths = []
    log_root = f'{LOGDIR}/'
    log_name = 'log.jsonl'
    for s in strs:
        for fp in glob.glob(os.path.join(log_root, s, log_name)):
            localpaths.append((fp, fp[len(log_root):-len(log_name) - 1]))

    # Series types define what to show as the train and validation curves.
    series_types = args.series.split(',')
    series = [[] for _ in series_types]
    print('series to print:', series_types)
    # When epoch-based and step-based x axes are mixed, everything is put on
    # an epoch scale.
    convert_to_epochs = {srs.split(':')[1] in EPOCH_VARS for srs in series_types} == {True, False}
    for logpath, model_name in localpaths:
        for idx, (series_str, legend) in enumerate(zip(series_types, legends)):
            series_id, x_var, y_var = series_str.split(':')
            s = Series(logpath, model_name, series_id, x_var, y_var, args.average,
                       base=args.base, convert_to_epochs=convert_to_epochs, legend=legend)
            if len(s.x) > 0 and len(s.y) > 0:
                series[idx].append(s)
    if not any(series):
        raise SystemExit(f'No data found for series {args.series} in models {strs}')

    # Enough colors for the longest curve group (groups may differ in size).
    n_colors = max(len(srs_list) for srs_list in series)
    colors = plt.cm.gist_rainbow(np.linspace(0, 1, n_colors))
    if args.show:
        plt.figure(figsize=(5, 5))
    else:
        plt.figure(figsize=(20, 20))
    # Sort to keep colors consistent across plottings.
    for srs_list in series:
        srs_list.sort(key=lambda x: x.name)

    ymin_data = []
    ymax_data = []
    # Group 0 (validation for --train_valid) is drawn solid, group 1 dashed;
    # further groups alternate. Groups are drawn last-to-first so that
    # group 0 ends up on top.
    linestyles = ["-", "--"]
    for group_idx in reversed(range(len(series))):
        style = linestyles[group_idx % len(linestyles)]
        alpha = 0.7 if style == '--' else 1.0
        for idx, srs in enumerate(series[group_idx]):
            plt.plot(srs.x, srs.y, linestyle=style, color=colors[idx], label=srs.name, alpha=alpha)
            ymax_data.append(srs.y.max())
            ymin_data.append(srs.y.min())

    plt.grid(linestyle="--")
    if args.logy:
        plt.yscale('log')
    if args.logx:
        plt.xscale('log')
    if args.ylim:
        ymin, ymax = _parse_limits(args.ylim)
        plt.ylim(ymin, ymax)
    else:
        ymin, ymax = min(ymin_data), max(ymax_data)
    # np.arange rejects a zero step, which happens for a flat curve.
    if ymax > ymin:
        plt.yticks(np.arange(ymin, ymax, (ymax - ymin) / 50))
    if args.xlim:
        xmin, xmax = _parse_limits(args.xlim)
        plt.xlim(xmin, xmax)

    os.makedirs(GRAPHDIR, exist_ok=True)
    fname = args.title + args.series.replace(":", "-").replace(",", "-")
    outpath = os.path.join(GRAPHDIR, fname[:100] + '.png')
    plt.title(f"{args.series} for {args.model}")
    plt.legend()
    plt.savefig(outpath)

    # Report each curve's extreme point, sorted ascending by that extreme.
    if args.max:
        for srs_list in series:
            srs_list.sort(key=lambda x: x.ymax)
            for s in srs_list:
                print(s.ymax, s.xmax, s.name)
    else:
        for srs_list in series:
            srs_list.sort(key=lambda x: x.ymin)
            for s in srs_list:
                print(s.ymin, s.xmin, s.name)

    if args.show:
        plt.show()
    else:
        print('Opening.')
        # Presumably the macOS `open` command; other platforms may lack it.
        subprocess.call(['open', outpath])