in src/util.py [0:0]
def save_model_checkpoint(model, step, config,
                          dev_loss_full, dev_ppl_full,
                          dev_loss_head, dev_ppl_head,
                          dev_loss_tail, dev_ppl_tail):
    """Save the model artifact along with basic statistics about the checkpoint.

    Creates ``<experiment_directory>/checkpoints/checkpoint_step_<step>/``
    containing ``info.txt`` (the dev metrics passed in) and ``model.pt``
    (the object passed as ``model``, serialized with ``torch.save``).

    Args:
        model: Object handed to ``torch.save`` as-is.
        step: Step identifier; interpolated into the directory name and
            the first line of ``info.txt``.
        config: Object exposing ``get(section, option)``; the value of
            ``("EXPERIMENT", "experiment_directory")`` is used as the root
            directory. (Looks like a ``configparser.ConfigParser`` --
            confirm against the caller.)
        dev_loss_full, dev_ppl_full: Values written on the "Full" line.
        dev_loss_head, dev_ppl_head: Values written on the "Head" line.
        dev_loss_tail, dev_ppl_tail: Values written on the "Tail" line.

    Missing directories (including the experiment directory itself) are
    created. If a checkpoint directory for ``step`` already exists, its
    ``info.txt`` and ``model.pt`` are overwritten instead of raising
    ``FileExistsError``.
    """
    experiment_directory = config.get("EXPERIMENT", "experiment_directory")
    curr_checkpoint = os.path.join(
        experiment_directory, "checkpoints", f"checkpoint_step_{step}"
    )
    # makedirs(exist_ok=True) creates any missing parents in one call and
    # avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(curr_checkpoint, exist_ok=True)

    info_path = os.path.join(curr_checkpoint, "info.txt")
    with open(info_path, "w", encoding="utf-8") as f_out:
        f_out.writelines([
            f"Step {step}\n",
            f"Full Dev Loss: {dev_loss_full} - Dev PPL {dev_ppl_full}\n",
            f"Head Dev Loss: {dev_loss_head} - Dev PPL {dev_ppl_head}\n",
            f"Tail Dev Loss: {dev_loss_tail} - Dev PPL {dev_ppl_tail}",
        ])

    # NOTE(review): this pickles the whole model object rather than its
    # state_dict; kept as-is because existing loaders may depend on it.
    torch.save(model, os.path.join(curr_checkpoint, "model.pt"))