# save_model_checkpoint() — defined in src/util.py

def save_model_checkpoint(model, step, config,
                          dev_loss_full, dev_ppl_full,
                          dev_loss_head, dev_ppl_head,
                          dev_loss_tail, dev_ppl_tail):
    """Save a model checkpoint along with a text summary of dev metrics.

    Creates ``<experiment_directory>/checkpoints/checkpoint_step_<step>/``
    containing ``info.txt`` (the step number plus dev loss/perplexity for
    the full, head, and tail dev splits) and ``model.pt`` (the model
    serialized with ``torch.save``).

    Args:
        model: torch model to serialize (the whole object is saved).
        step: training step number; used to name the checkpoint directory.
        config: configparser-style object; reads
            ``("EXPERIMENT", "experiment_directory")``.
        dev_loss_full, dev_ppl_full: loss/perplexity on the full dev set.
        dev_loss_head, dev_ppl_head: loss/perplexity on the head dev split.
        dev_loss_tail, dev_ppl_tail: loss/perplexity on the tail dev split.
    """
    experiment_directory = config.get("EXPERIMENT", "experiment_directory")
    checkpoints_dir = os.path.join(experiment_directory, "checkpoints")

    curr_checkpoint = os.path.join(checkpoints_dir, f"checkpoint_step_{step}")
    # makedirs(exist_ok=True) replaces the racy os.path.exists + os.mkdir
    # pair, creates any missing parent directories in one call, and lets a
    # checkpoint for the same step be re-saved without raising.
    os.makedirs(curr_checkpoint, exist_ok=True)

    with open(os.path.join(curr_checkpoint, "info.txt"), "w") as f_out:
        f_out.write(f"Step {step}\n")
        f_out.write(f"Full Dev Loss: {dev_loss_full} - Dev PPL {dev_ppl_full}\n")
        f_out.write(f"Head Dev Loss: {dev_loss_head} - Dev PPL {dev_ppl_head}\n")
        f_out.write(f"Tail Dev Loss: {dev_loss_tail} - Dev PPL {dev_ppl_tail}")

    # NOTE(review): saving the full model object pickles the class, tying
    # the checkpoint to the current code layout; consider
    # torch.save(model.state_dict(), ...) if portability matters.
    torch.save(model, os.path.join(curr_checkpoint, "model.pt"))