#!/usr/bin/env python
# coding: utf-8
# scripts/visualize_metrics.py
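
# Visualize the training/validation metrics logged for a collection of grok
# experiments: loss/accuracy curves, max-accuracy curves across training-data
# percentages, and t-SNE plots of saved activations.
#
# Example invocation (paths are illustrative):
#   python scripts/visualize_metrics.py -i ./runs/subtraction -o ./figures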
import csv
import logging
import os
import subprocess
from argparse import ArgumentParser
from copy import deepcopy
from glob import glob

import grok
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import torch
import yaml
from sklearn.manifold import TSNE  # used by create_tsne_graphs
from tqdm import tqdm
logger = logging.getLogger(__name__)
# take args: input_dir output_dir
parser = ArgumentParser()
parser.add_argument(
"-i",
"--input_dir",
type=str,
required=True,
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
required=True,
)
parser = grok.training.add_args(parser)
args = parser.parse_args()
print(args, flush=True)
# use the GPU when available; saved activations are mapped onto this device
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_expt_metrics(
expt_dir,
args,
):
"""load the metrics for one experiment"""
args = deepcopy(args)
# load the hparams for this experiment
with open(f"{expt_dir}/default/version_0/hparams.yaml", "r") as fh:
hparams_dict = yaml.safe_load(fh)
for k, v in hparams_dict.items():
setattr(args, k, v)
# load the summarized validation and training data for every epoch
val_data = {
"step": [],
"epoch": [],
"val_loss": [],
"val_accuracy": [],
}
train_data = {
"step": [],
"epoch": [],
"train_loss": [],
"train_accuracy": [],
"learning_rate": [],
}
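    # metrics.csv interleaves training and validation rows; whichever split a
    # row does not belong to has empty cells. An illustrative excerpt (column
    # order assumed, not taken from a real log):
    #   step,epoch,train_loss,train_accuracy,learning_rate,val_loss,val_accuracy
    #   400,1,2.31,12.50,0.0010,,
    #   400,1,,,,2.35,11.20
    # A non-empty train_loss therefore marks a training row below.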
with open(f"{expt_dir}/default/version_0/metrics.csv", "r") as fh:
for row in csv.DictReader(fh):
if row["train_loss"] != "":
for k in train_data:
if k in ["step", "epoch"]:
v = int(row[k])
else:
v = float(row[k])
train_data[k].append(v)
else:
for k in val_data:
if k in ["step", "epoch"]:
v = int(row[k])
else:
v = float(row[k])
val_data[k].append(v)
return {
"hparams": hparams_dict,
"train": train_data,
"val": val_data,
# "raw": raw_data,
}
def load_run_metrics(
run_dir,
args=args,
):
"""load all the metrics for a collection of experiments with the same architecture
across various amounts of training data"""
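    # Result maps train_data_pct -> experiment dict, so (illustrative key)
    # metric_data[50.0]["val"]["val_accuracy"] is the validation-accuracy
    # series for the run trained on 50% of the data.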
metric_data = {}
    # immediate subdirectories of run_dir, one per experiment
    _, expt_dirs, _ = next(os.walk(run_dir))
for expt_dir in tqdm(expt_dirs, unit="expt"):
try:
expt_data = load_expt_metrics(f"{run_dir}/{expt_dir}", args)
train_data_pct = expt_data["hparams"]["train_data_pct"]
metric_data[train_data_pct] = expt_data
        except FileNotFoundError:
            # skip experiment dirs that have no logged hparams/metrics yet
            pass
return metric_data
def add_metric_graph(
fig,
ax,
arch,
metric,
metric_data,
scales,
cmap="viridis",
by="step", # step or epoch
max_increment=0,
):
    """Plot one metric (e.g. val_loss) against step/epoch, one curve per
    training-data percentage, colored along a shared colormap."""
    ax.set_title(metric)
ax.set_xscale(scales["x"])
ax.set_yscale(scales["y"])
ax.set_xlabel(by)
if "accuracy" in metric:
ax.yaxis.set_major_formatter(mtick.PercentFormatter())
ymin = 1e-16
ymax = 101
ax.axis(ymin=ymin, ymax=ymax)
if "loss" in metric:
ymin = 1e-16
ymax = 15
ax.axis(ymin=ymin, ymax=ymax)
    logger.debug(f"processing {metric}")
T = list(sorted(metric_data.keys()))
T_max = int(T[-1])
T_min = int(T[0])
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=T[0], vmax=T[-1]))
colors = sm.to_rgba(T)
for i, t in enumerate(T):
if "val" in metric:
this_data = metric_data[t]["val"]
else:
this_data = metric_data[t]["train"]
X = this_data[by]
Y = this_data[metric]
if max_increment > 0:
X = [x for x in X if x <= max_increment]
Y = Y[: len(X)]
if len(X) != len(Y):
logger.warning(f"Mismatched data: {metric} at t={t}")
continue
if not Y:
logger.warning(f"No data for {metric}i at t={t}")
continue
label = arch + f" t={t}"
if "accuracy" in metric:
label += " (max = %.2f)" % max(Y)
elif "loss" in metric:
label += " (min = %.2f)" % min(Y)
        ax.plot(X, Y, label=label, color=colors[i])
    # a legend stays readable for a handful of runs; otherwise use a colorbar
    if T_max - T_min <= 10:
ax.legend()
else:
fig.colorbar(
sm,
ax=ax,
label="% training data",
ticks=range(T_min, T_max + 1, int((T_max - T_min) / 5)),
)
def add_max_accuracy_graph(
ax,
arch,
metric,
metric_data,
scales,
by="step",
max_increment=0,
):
    """Plot the best accuracy reached within the first max_increment
    steps/epochs against the percentage of data trained on."""
    ax.set_title(f"max {metric}")
ax.set_xlabel("% of total data trained on")
xmin = 0
xmax = 100
ymin = 1e-16
ymax = 101
ax.axis(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
ax.set_xscale(scales["x"])
ax.set_yscale(scales["y"])
ax.yaxis.set_major_formatter(mtick.PercentFormatter())
ax.xaxis.set_major_formatter(mtick.PercentFormatter())
T = list(sorted(metric_data.keys()))
T_max = int(T[-1])
T_min = int(T[0])
Y = []
    for t in T:
if "val" in metric:
this_data = metric_data[t]["val"]
else:
this_data = metric_data[t]["train"]
X = this_data[by]
if max_increment > 0:
X = [x for x in X if x <= max_increment]
max_idx = len(X)
        else:
            max_idx = None  # use the full series; -1 would drop the last point
        try:
            Y.append(max(this_data[metric][:max_idx]))
        except ValueError:
            # no recorded points within this window yet
            Y.append(np.nan)
    ax.set_xticks(np.arange(0, 101, 5))  # include the 100% tick
label = f"max {metric} {arch}"
ax.plot(T, Y, label=label)
def create_loss_curves(
metric_data,
arch,
operation,
most_interesting_only=False,
image_dir=args.output_dir,
by="step",
max_increment=0,
cmap="viridis",
):
    """Write a 3x2 grid of loss/accuracy/learning-rate curves for one run
    collection to image_dir."""
    scales = {
"x": "log",
"y": "linear",
}
ncols = 2
nrows = 3
fig_width = ncols * 8
fig_height = nrows * 5
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))
    # one panel per metric: (row, col, metric name)
    panels = [
        (0, 0, "val_loss"),
        (0, 1, "val_accuracy"),
        (1, 0, "train_loss"),
        (1, 1, "train_accuracy"),
        (2, 0, "learning_rate"),
    ]
    for row, col, metric in panels:
        add_metric_graph(
            fig,
            axs[row, col],
            arch,
            metric,
            metric_data,
            scales,
            cmap,
            by,
            max_increment=max_increment,
        )
fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s")
fig.tight_layout()
img_file = f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}__upto_{max_increment:010d}_{by}"
if most_interesting_only:
img_file += "_most_interesting"
img_file += ".png"
d = os.path.split(img_file)[0]
os.makedirs(d, exist_ok=True)
print(f"Writing {img_file}")
fig.savefig(img_file)
plt.close(fig)
def create_max_accuracy_curves(
metric_data,
arch,
operation,
by="step",
max_increment=0,
image_dir=args.output_dir,
):
    """Write stacked val/train max-accuracy vs. train_data_pct plots,
    truncated at max_increment steps/epochs."""
    scales = {
"x": "linear",
"y": "linear",
}
ncols = 1
nrows = 2
fig_width = ncols * 8
fig_height = nrows * 5
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))
add_max_accuracy_graph(
axs[0],
arch,
"val_accuracy",
metric_data,
scales,
by=by,
max_increment=max_increment,
)
axs[0].legend()
add_max_accuracy_graph(
axs[1],
arch,
"train_accuracy",
metric_data,
scales,
by=by,
max_increment=max_increment,
)
axs[1].legend()
fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s")
fig.tight_layout()
img_file = f"{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}_upto_{max_increment:010d}_{by}.png"
d = os.path.split(img_file)[0]
os.makedirs(d, exist_ok=True)
print(f"Writing {img_file}")
fig.savefig(img_file)
plt.close(fig)
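# The zero-padded `upto_` counter in the filenames written above keeps the
# frames in lexicographic == chronological order for the ffmpeg step at the
# bottom of this script.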
def create_tsne_graphs(
operation,
expt,
run_dir,
image_dir=args.output_dir,
):
    """Run t-SNE on per-equation validation-loss trajectories collected from
    saved activation checkpoints and write a scatter plot."""
    saved_pt_dir = f"{run_dir}/activations"
saved_pts = []
loss_ts = []
accuracy_ts = []
epochs_ts = []
    pattern = saved_pt_dir + "/activations_*.pt"
    print(f"glob = {pattern}")
    # `glob` is the function imported from the glob module above
    files = sorted(glob(pattern))
    print(f"files = {files}")
for file in files:
print(f"Loading {file}")
        saved_pt = torch.load(file, map_location=device)
saved_pts.append(saved_pt)
loss_ts.append(saved_pt["val_loss"].mean(dim=-1))
accuracy_ts.append(saved_pt["val_accuracy"])
epochs_ts.append(saved_pt["epochs"].squeeze())
    # one row per equation, one column per saved checkpoint
    loss_t = torch.cat(loss_ts, dim=0).T.detach().cpu()
    accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach().cpu()
    epochs_t = torch.cat(epochs_ts, dim=0).detach().cpu()
print(loss_t.shape)
print(accuracy_t.shape)
print(epochs_t.shape)
    # project each equation's loss trajectory down to 2D
    print("Doing t-SNE...")
    loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t.numpy())
    print("...done t-SNE.")
ncols = 1
nrows = 1
fig_width = ncols * 8
fig_height = nrows * 5
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))
axs.scatter(loss_tsne[:, 0], loss_tsne[:, 1])
img_file = f"{image_dir}/tsne/{operation}_{expt}.png"
d = os.path.split(img_file)[0]
os.makedirs(d, exist_ok=True)
print(f"Writing {img_file}")
fig.savefig(img_file)
plt.close(fig)
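# NOTE: create_tsne_graphs is not called by the main block below; it is kept
# for ad-hoc use on runs that saved activation checkpoints.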
def get_arch(metric_data):
    """Build an architecture/run tag from the hparams of any one experiment."""
    k = list(metric_data.keys())[0]
hparams = metric_data[k]["hparams"]
arch = f'L-{hparams["n_layers"]}_H-{hparams["n_heads"]}_D-{hparams["d_model"]}_B-{hparams["batchsize"]}_S-{hparams["random_seed"]}_DR-{hparams["dropout"]}'
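    # e.g. "L-2_H-4_D-128_B-512_S-0_DR-0.0" (illustrative values)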
return arch
def get_operation(metric_data):
    """Recover the dataset/operation name from the recorded hparams."""
    k = list(metric_data.keys())[0]
hparams = metric_data[k]["hparams"]
operator = hparams["math_operator"]
operand_length = hparams["operand_length"]
_, operation = grok.data.ArithmeticDataset.get_file_path(operator, operand_length)
return operation
def get_max_epochs(metric_data):
    """Read max_epochs from the hparams of any one experiment."""
    k = list(metric_data.keys())[0]
hparams = metric_data[k]["hparams"]
return hparams["max_epochs"]
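# Main flow: load all experiments under input_dir, draw loss curves by step
# and by epoch, render one max-accuracy frame per log-spaced epoch cutoff,
# then stitch the frames into a video with ffmpeg.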
rundir = args.input_dir
try:
metric_data = load_run_metrics(rundir, args)
arch = get_arch(metric_data)
operation = get_operation(metric_data)
max_epochs = get_max_epochs(metric_data)
for by in ["step", "epoch"]:
create_loss_curves(metric_data, arch, operation, by=by)
by = "epoch"
last_i = -1
for i in sorted(list(set(2 ** (np.arange(167) / 10)))):
if i > max_epochs:
break
i = int(round(i))
create_max_accuracy_curves(
metric_data,
arch,
operation,
by=by,
max_increment=i,
)
    # stitch the per-cutoff max-accuracy frames into a video
    in_files = os.path.join(
        args.output_dir,
        "max_accuracy",
        f"{operation}_max_accuracy_{arch}_upto_*_{by}.png",
    )
    out_file = os.path.join(args.output_dir, f"{operation}_{arch}_max_accuracy.mp4")
    cmd = [
        "ffmpeg",
        "-y",
        "-r",
        "16",
        "-pattern_type",
        "glob",  # match frames by shell-style wildcard rather than a %d sequence
        "-i",
        in_files,
        "-vcodec",
        "libx264",
        "-crf",
        "25",
        "-pix_fmt",
        "yuv420p",
        out_file,
    ]
subprocess.check_call(cmd)
except Exception as e:
    print(f"{rundir} failed: {e}")