scripts/create_metric_graphs.py (218 lines of code) (raw):

#!/usr/bin/env python
# coding: utf-8

# Render metrics graphs

import csv
import logging
import os
import glob
import socket
from argparse import ArgumentParser
from collections import defaultdict

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm
from sklearn.manifold import TSNE

import grok
from grok.visualization import *

# from grok_runs import RUNS

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("grok.view_metrics")
logger.setLevel(logging.ERROR)

# Maps an operation name to a (ds_len, run directory relative to DATA_DIR)
# pair; this is how the main loop at the bottom of the file unpacks it.
RUNS = {
    "subtraction": (
        9409,
        "subtraction/2021-02-05-03-33-56-alethea-sjjf",
    ),
}

# Filter bounds handed to get_metric_data(). Every metric needs both a
# "min_<metric>" and a "max_<metric>" entry.
limits = {
    "min_val_accuracy": 0,
    "max_val_accuracy": 100,
    "min_T": 0,  # 0
    "max_T": 100,  # 87.5
    "min_D": 0,  # 8
    "max_D": 256,  # 256
    "min_H": 0,  # 1
    "max_H": 4,  # 8
    "min_L": 0,  # 1
    "max_L": 4,  # 4
    "min_accuracy": 0,
    "max_accuracy": 100,
}

# Sanity-check the table above: each metric must have min <= max (a missing
# counterpart key surfaces as a KeyError).
for k in limits:
    metric = k.replace("min_", "").replace("max_", "")
    assert (
        limits["max_" + metric] >= limits["min_" + metric]
    ), f"invalid {metric} limits"

parser = ArgumentParser()
parser.add_argument("-i", "--image_dir", type=str, default=IMAGE_DIR)
args = parser.parse_args()


def create_loss_curves(
    metric_data,
    epochs,
    run,
    most_interesting_only=False,
    image_dir=args.image_dir,
    ds_len=None,
    cmap=DEFAULT_CMAP,
):
    """Draw the loss / accuracy / learning-rate panels and save them as a PNG.

    The figure is a 3x2 grid: validation loss and accuracy, training loss
    and accuracy, and learning rate (the sixth panel is left empty). It is
    written to ``{image_dir}/loss_curves/``.

    Args:
        metric_data: mapping keyed by architecture name, as returned by
            ``get_metric_data`` / ``most_interesting``. The first key is
            used for the figure title and the output file name.
        epochs: unused; kept so existing callers keep working.
        run: unused; kept so existing callers keep working.
        most_interesting_only: when true, ``_most_interesting`` is appended
            to the output file name.
        image_dir: root directory for the generated images.
        ds_len: forwarded to ``add_metric_graph``; when not ``None``,
            ``_by_update`` is appended to the output file name.
        cmap: colour map forwarded to ``add_metric_graph``.

    Note:
        The operation name in the title and file name is read from the
        module-level ``operation`` variable set by the main loop.
    """
    scales = {
        "x": "log",
        "y": "linear",
    }

    arch = list(metric_data.keys())[0]

    ncols = 2
    nrows = 3
    fig_width = ncols * 8
    fig_height = nrows * 5
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))

    panels = (
        (axs[0, 0], "val_loss"),
        (axs[0, 1], "val_accuracy"),
        (axs[1, 0], "train_loss"),
        (axs[1, 1], "train_accuracy"),
    )
    for ax, metric_name in panels:
        add_metric_graph(
            fig, ax, metric_name, metric_data, scales, cmap=cmap, ds_len=ds_len
        )

    # ds_len is not forwarded for the learning-rate panel.
    add_metric_graph(fig, axs[2, 0], "learning_rate", metric_data, scales, cmap=cmap)

    # Title with the architecture that is actually plotted (the same one
    # used in the file name) instead of reaching for the global `data`.
    fig.suptitle(f"{operation} {arch}")
    fig.tight_layout()

    img_file = f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}"
    if ds_len is not None:
        img_file += "_by_update"
    if most_interesting_only:
        img_file += "_most_interesting"
    img_file += ".png"
    d = os.path.split(img_file)[0]
    os.makedirs(d, exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)


def create_max_accuracy_curves(
    metric_data, epochs, run, image_dir=args.image_dir, ds_len=None
):
    """Plot maximum validation and training accuracy and save them as a PNG.

    The figure has two stacked panels (validation on top, training below)
    and is written to ``{image_dir}/max_accuracy/``.

    Args:
        metric_data: mapping keyed by architecture name. The first key is
            used for the figure title and the output file name.
        epochs: unused; kept so existing callers keep working.
        run: unused; kept so existing callers keep working.
        image_dir: root directory for the generated images.
        ds_len: unused; kept so existing callers keep working.

    Note:
        The operation name in the title and file name is read from the
        module-level ``operation`` variable set by the main loop.
    """
    # Previously the file name used a leaked module-level loop variable
    # `arch`; derive it from the data being plotted instead.
    arch = list(metric_data.keys())[0]

    ncols = 1
    nrows = 2
    fig_width = ncols * 8
    fig_height = nrows * 5
    # squeeze=False always yields a 2-D array of axes, so indexing is
    # uniform whatever nrows / ncols are set to.
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height), squeeze=False
    )

    add_extremum_graph(
        axs[0, 0], "val_accuracy", "max", metric_data, show_legend=False
    )
    add_extremum_graph(
        axs[1, 0], "train_accuracy", "max", metric_data, show_legend=False
    )

    fig.suptitle(f"{operation} {arch}")
    fig.tight_layout()

    img_file = f"{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}.png"
    d = os.path.split(img_file)[0]
    os.makedirs(d, exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)


def create_tsne_graphs(operation, expt, run_dir, image_dir=args.image_dir):
    """Run a 2-D t-SNE over saved validation losses and save a scatter plot.

    Loads every ``activations_*.pt`` file under ``{run_dir}/activations`` in
    sorted order, concatenates their ``val_loss`` (averaged over the last
    dimension), ``val_accuracy`` and ``epochs`` entries, and embeds the
    transposed loss matrix with t-SNE. The plot is written to
    ``{image_dir}/tsne/{operation}_{expt}.png``.

    Args:
        operation: operation name, used in the output file name.
        expt: experiment name, used in the output file name.
        run_dir: directory of the experiment containing ``activations/``.
        image_dir: root directory for the generated images.
    """
    saved_pt_dir = f"{run_dir}/activations"
    loss_ts = []
    accuracy_ts = []
    epochs_ts = []
    print(f'glob = {saved_pt_dir + "/activations_*.pt"}')
    files = sorted(glob.glob(saved_pt_dir + "/activations_*.pt"))
    print(f"files = {files}")

    for file in files:
        print(f"Loading {file}")
        # NOTE: torch.load unpickles the file; only point this at run
        # directories you trust.
        saved_pt = torch.load(file)
        loss_ts.append(saved_pt["val_loss"].mean(dim=-1))
        accuracy_ts.append(saved_pt["val_accuracy"])
        epochs_ts.append(saved_pt["epochs"].squeeze())

    loss_t = torch.cat(loss_ts, dim=0).T.detach()
    accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach()
    epochs_t = torch.cat(epochs_ts, dim=0).detach()
    print(loss_t.shape)
    print(accuracy_t.shape)
    print(epochs_t.shape)

    print("Doing T-SNE..")
    loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t)
    print("...done T-SNE.")

    ncols = 1
    nrows = 1
    fig_width = ncols * 8
    fig_height = nrows * 5
    # A 1x1 grid makes plt.subplots return a single Axes object.
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))
    ax.scatter(loss_tsne[:, 0], loss_tsne[:, 1])

    img_file = f"{image_dir}/tsne/{operation}_{expt}.png"
    d = os.path.split(img_file)[0]
    os.makedirs(d, exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)


for operation in RUNS:
    print("")
    print("")
    print(f"Processing {operation}", flush=True)

    # Operation names ending in "-epochs" carry the epoch count as the
    # leading "-"-separated field of their last path component.
    if operation.endswith("-epochs"):
        epochs = int(operation.split("/")[-1].split("-")[0])
    else:
        epochs = 5000

    ds_len, run = RUNS[operation]

    data = load_metric_data(f"{DATA_DIR}/{run}", epochs=epochs, load_partial_data=False)

    # check it
    for arch in data:
        # Use distinct names here: unpacking into `epochs` used to shadow
        # the requested epoch count and made the last check a tautology.
        n_metrics, n_expts, n_epochs = data[arch]["metrics"].shape
        message = (
            f"{arch} : loaded {n_metrics} metrics, {n_expts} experiments, "
            f"{n_epochs} epochs"
        )
        assert n_metrics == 5, "INVALID metrics count: " + message
        assert n_expts < 88, "INVALID experiments count: " + message
        assert n_epochs == epochs, "INVALID epochs count: " + message
        print(message)

    # ## Set filters on the data to view
    metric_data = get_metric_data(data, limits)

    # Draw max accuracy curves
    create_max_accuracy_curves(metric_data, epochs, run)

    # Draw loss and accuracy curves
    create_loss_curves(metric_data, epochs, run)
    create_loss_curves(metric_data, epochs, run, ds_len=ds_len)

    most_interesting_metric_data = most_interesting(metric_data)
    create_loss_curves(
        most_interesting_metric_data, epochs, run, most_interesting_only=True
    )
    create_loss_curves(
        most_interesting_metric_data,
        epochs,
        run,
        most_interesting_only=True,
        ds_len=ds_len,
    )

    # T-SNE of loss curves. This stays best-effort (a failure must not
    # abort the remaining operations), but only ordinary exceptions are
    # caught so Ctrl-C still works, and the traceback is logged.
    try:
        for arch in most_interesting_metric_data:
            t = int(most_interesting_metric_data[arch]["T"][0].item())
            expt = f"{arch}_T-{t}_DROP-0.0"
            create_tsne_graphs(operation, expt, run_dir=f"{DATA_DIR}/{run}/{expt}")
    except Exception:
        print("TSNE failed")
        logger.exception("TSNE failed for operation %s", operation)