in grok/training.py [0:0]
def train(hparams: Namespace) -> str:
    """
    This is the main trainer_method. It sets up and runs an experiment with
    the defined hyperparameters.

    ``hparams`` is modified in place: ``logdir`` is replaced by its absolute
    path, and ``d_key`` and ``checkpoint_path`` are set on it.

    Side effects: creates ``<logdir>/checkpoints`` if needed and saves the
    freshly constructed model there as ``init.pt`` before training starts.

    :param hparams: An argparse.Namespace with all of the relevant hyperparameters
    :returns: the absolute path of the log directory
    :raises AssertionError: if ``n_heads`` does not evenly divide ``d_model``
    """
    # Process the args. An unset logdir falls back to $LOGDIR, then to the
    # current directory.
    if hparams.logdir is None:
        hparams.logdir = os.environ.get("LOGDIR", ".")
    hparams.logdir = os.path.abspath(hparams.logdir)

    # Make sure d_model, heads, and d_key are compatible
    assert (
        hparams.d_model % hparams.n_heads == 0
    ), "n_heads=%s does not evenly divide d_model=%s" % (
        hparams.n_heads,
        hparams.d_model,
    )
    # Floor division keeps d_key an int; the check above guarantees the
    # quotient is exact, whereas "/" would always produce a float.
    hparams.d_key = hparams.d_model // hparams.n_heads

    # Set up the RNGs for repeatability (a seed of -1 means "do not seed")
    if hparams.random_seed != -1:
        torch.manual_seed(hparams.random_seed)
        torch.cuda.manual_seed(hparams.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    checkpoint_path = os.path.join(hparams.logdir, "checkpoints")
    os.makedirs(checkpoint_path, exist_ok=True)
    hparams.checkpoint_path = checkpoint_path

    # Create the model and keep a copy of its initial state on disk
    model = TrainableTransformer(hparams).float()
    torch.save(model, os.path.join(checkpoint_path, "init.pt"))

    logger = CSVLogger(hparams.logdir)

    # Checkpoint callback is currently disabled:
    # checkpointer = ModelCheckpoint(
    #     filepath=checkpoint_path,
    #     monitor="save_ckpt",
    #     mode="max",
    #     save_top_k=len(hparams.ckpt_epochs),
    #     verbose=False,
    # )

    trainer_args = {
        # min_steps == max_steps; presumably so the run lasts exactly
        # max_steps, with max_epochs set too high to ever be the limit.
        "max_steps": hparams.max_steps,
        "min_steps": hparams.max_steps,
        "max_epochs": int(1e8),
        "val_check_interval": 1,
        "profiler": False,
        # "checkpoint_callback": checkpointer,
        "logger": logger,
        "log_every_n_steps": 1,
        "flush_logs_every_n_steps": 1000,
    }
    # Only request a GPU when CUDA is available and a non-negative index
    # was given.
    if torch.cuda.is_available() and hparams.gpu >= 0:
        trainer_args["gpus"] = [hparams.gpu]

    trainer = Trainer(**trainer_args)

    trainer.fit(model=model)  # type: ignore

    # Disabled post-training metrics dump. NOTE: it references names that are
    # not defined in this function (transformer, transformer_init,
    # dataset_size), so it cannot be re-enabled as written.
    #
    # margin = np.percentile(model.margin.detach().cpu().numpy(), 5)
    # device = transformer.embedding.weight.device
    # measures, bounds = metrics.calculate(
    #     transformer,
    #     transformer_init.to(device),
    #     device,
    #     dataset_size,
    #     margin,
    #     input_dim=hparams.d_model,
    # )
    #
    # measures_file = os.path.join(logger.log_dir, "measures.json")
    # bounds_file = os.path.join(logger.log_dir, "bounds.json")
    # with open(measures_file, "w") as fh:
    #     json.dump(measures, fh)
    # with open(bounds_file, "w") as fh:
    #     json.dump(bounds, fh)

    return hparams.logdir