scripts/create_partial_metrics.py:
#!/usr/bin/env python
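"""Compute partial validation metrics for saved grok checkpoints.

For each requested epoch, the matching checkpoint is loaded, run through the
model's test dataloader, and the resulting val_loss/val_accuracy curves are
saved to <expt_dir>/activations/activations_<first>_<last>.pt.  Invoked
without --expt_dir, the script instead sweeps every run in grok_runs.RUNS,
re-invoking itself as a subprocess per experiment.
"""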
import logging

# Silence sub-ERROR log output before the heavier imports run.
logging.basicConfig(level=logging.ERROR)

import glob
import os
import subprocess
import sys
from argparse import ArgumentParser

import torch
from tqdm import tqdm

from grok import trainer
from grok_runs import RUNS
from grok_metrics_lib import (
    DATA_DIR,
    load_metric_data,
    get_metric_data,
    most_interesting,
)

# Make N_EPOCHS - 1 exponentially spaced epochs from 1 to 9999 (integer
# truncation can repeat values at the low end).
N_EPOCHS = 32
BASE = 9999 ** (1.0 / (N_EPOCHS - 1))
epochs = (BASE ** torch.arange(1, N_EPOCHS).float()).long().tolist()
DEFAULT_EPOCHS = ",".join(str(i) for i in epochs)

parser = ArgumentParser()
parser.add_argument("--expt_dir", type=str, help="where to find the runs")
parser.add_argument(
    "--epochs",
    type=str,
    default=DEFAULT_EPOCHS,
    help="comma-separated list of epochs whose checkpoints to evaluate",
)


def child(hparams):
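    """Evaluate one run's checkpoints at the requested epochs and save metrics."""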
expt_dir = hparams.expt_dir
epochs = [int(e) for e in hparams.epochs.split(",")]
# print("epochs = ", epochs)
device = torch.device(f"cuda:{hparams.gpu}")
ckpt_dir = expt_dir + "/" + "checkpoints"
# ckpt_files = [ckpt_dir + f"/epoch={epoch}.ckpt" for epoch in epochs]
hparams.logdir = expt_dir
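    # One row per successfully loaded checkpoint, concatenated along dim 0.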
results = {
"val_loss": None,
"val_accuracy": None,
}
processed_epochs = []
    for epoch in tqdm(epochs):
        # Checkpoint filenames may or may not carry a -step=* suffix,
        # depending on how the run was saved; match both forms.
        ckpt_files = glob.glob(ckpt_dir + f"/epoch={epoch}-step=*.ckpt")
        ckpt_files += glob.glob(ckpt_dir + f"/epoch={epoch}.ckpt")
        try:
            ckpt_file = ckpt_files[-1]
            ckpt = torch.load(ckpt_file, map_location=device)
            processed_epochs.append(epoch)
        except (IndexError, FileNotFoundError):
            # No checkpoint was saved for this epoch; skip it.
            continue
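        # Restore the hyperparameters this checkpoint was trained with.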
for k, v in ckpt["hyper_parameters"].items():
setattr(hparams, k, v)
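        # Some checkpoints store weights without the "transformer." key prefix
        # that TrainableTransformer expects; normalize every key to carry it.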
new_state_dict = {}
for k, v in ckpt["state_dict"].items():
if k.startswith("transformer."):
new_state_dict[k] = v
else:
new_state_dict["transformer." + k] = v
ckpt["state_dict"] = new_state_dict
model = trainer.TrainableTransformer(hparams).float()
model.load_state_dict(ckpt["state_dict"])
model = model.to(device).eval()
        dl = model.test_dataloader()
        dl.reset_iteration(shuffle=False)
        with torch.no_grad():
            outputs = [model.test_step(batch, i) for i, batch in enumerate(dl)]
        r = model.test_epoch_end(outputs)["log"]
if results["val_loss"] is None:
results["val_loss"] = r["test_loss"].squeeze().unsqueeze(0)
results["val_accuracy"] = r["test_accuracy"].squeeze().unsqueeze(0)
else:
results["val_loss"] = torch.cat(
[results["val_loss"], r["test_loss"].squeeze().unsqueeze(0)], dim=0
)
results["val_accuracy"] = torch.cat(
[
results["val_accuracy"],
r["test_accuracy"].squeeze().unsqueeze(0),
],
dim=0,
)
    if not processed_epochs:
        return  # no checkpoints found; nothing to save
    for k, v in results.items():
        results[k] = v.to("cpu")
    results["epochs"] = torch.tensor(processed_epochs, dtype=torch.long)
    results["dl"] = dl
    os.makedirs(os.path.join(expt_dir, "activations"), exist_ok=True)
    ptfile = os.path.join(
        expt_dir, "activations", f"activations_{epochs[0]:010d}_{epochs[-1]:010d}.pt"
    )
    torch.save(results, ptfile)


if __name__ == "__main__":
hparams = trainer.get_args(parser)
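    # With --expt_dir, evaluate that single run directly; otherwise sweep
    # every run listed in grok_runs.RUNS.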
if hparams.expt_dir is not None:
child(hparams)
else:
for operation in RUNS:
print(f"running {operation}")
ds_len, run = RUNS[operation]
data = load_metric_data(
f"{DATA_DIR}/{run}", epochs=10000, load_partial_data=False
)
metric_data = get_metric_data(data)
metric_data = most_interesting(metric_data)
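            # most_interesting() keeps one T value per architecture; locate
            # and evaluate the experiment directory that matches it.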
for arch in metric_data:
interesting_t = int(metric_data[arch]["T"][0].item())
expt = f"{arch}_T-{interesting_t}"
print(f"--> expt {expt}")
glb = f"{DATA_DIR}/{run}/{expt}_*"
# print(f"glb {glb}")
expt_dir = glob.glob(glb)[0]
cmd = [sys.argv[0], "--expt_dir", expt_dir]
subprocess.run(cmd, check=False, shell=False)
# child(hparams)