in grok/training.py [0:0]
def compute_sharpness(hparams: Namespace, ckpts) -> None:
    """Compute the loss-landscape sharpness for a series of checkpoints.

    For each checkpoint path in ``ckpts``, rebuilds the model from the
    hyperparameters stored inside the checkpoint, restores its weights, and
    evaluates sharpness via ``get_sharpness`` on the model's training data.
    Accumulated results are pickled to ``results/results_SD-<idx>.pkl`` after
    each checkpoint so partial progress survives an interruption.

    :param hparams: argparse.Namespace with all relevant hyperparameters
        (logdir, d_model, n_heads, random_seed, max_steps, gpu, ...)
    :param ckpts: iterable of checkpoint file paths to evaluate
    """
    # Process the args
    if hparams.logdir is None:
        hparams.logdir = os.environ.get("LOGDIR", ".")
    hparams.logdir = os.path.abspath(hparams.logdir)

    # Make sure d_model, heads, and d_key are compatible
    assert (
        hparams.d_model % hparams.n_heads == 0
    ), "n_heads=%s does not evenly divide d_model=%s" % (
        hparams.n_heads,
        hparams.d_model,
    )
    # Integer division: the assert above guarantees exact divisibility, and
    # d_key is a dimension size, so it must be an int (plain `/` yields float).
    hparams.d_key = hparams.d_model // hparams.n_heads

    # Set up the RNGs for repeatability
    if hparams.random_seed != -1:
        torch.manual_seed(hparams.random_seed)
        torch.cuda.manual_seed(hparams.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    checkpoint_path = os.path.join(hparams.logdir, "checkpoints")
    os.makedirs(checkpoint_path, exist_ok=True)
    hparams.checkpoint_path = checkpoint_path

    # Create the model and snapshot its initial (untrained) state
    model = TrainableTransformer(hparams).float()
    torch.save(model, os.path.join(checkpoint_path, "init.pt"))

    logger = CSVLogger(hparams.logdir)

    trainer_args = {
        "max_steps": hparams.max_steps,
        "min_steps": hparams.max_steps,
        "max_epochs": int(1e8),
        "val_check_interval": 1,
        "profiler": False,
        "logger": logger,
        "log_every_n_steps": 1,
        "flush_logs_every_n_steps": 1000,
    }
    if torch.cuda.is_available() and hparams.gpu >= 0:
        trainer_args["gpus"] = [hparams.gpu]

    trainer = Trainer(**trainer_args)

    # Accumulate sharpness values across all checkpoints in a single dict
    # (previously the dict was re-created per iteration, so each pickle held
    # only the most recent checkpoint's result).
    results = {}
    os.makedirs("results", exist_ok=True)
    for idx, ckpt in enumerate(ckpts):
        print(f"Loading checkpoint {ckpt}")
        checkpoint = torch.load(ckpt)

        # Rebuild the model from the hyperparameters saved in the checkpoint,
        # then restore its trained weights.
        hps = argparse.Namespace(**checkpoint["hyper_parameters"])
        model = TrainableTransformer(hps).float()
        model.load_state_dict(checkpoint["state_dict"])

        phi = get_sharpness(model.train_dataloader(), model)
        results[ckpt] = phi
        # Dump after every checkpoint so partial results are preserved.
        # NOTE: `idx` replaces a previously undefined name `i`, which raised
        # NameError on the first dump.
        with open(f"results/results_SD-{idx}.pkl", "wb") as f:
            pickle.dump(results, f)