def save_checkpoint()

in archived/fairseq_sagemaker_translate_en2fr/fairseq/train_driver.py [0:0]


def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write whichever checkpoint files are due at this point in training.

    Does nothing when ``args.no_save`` is set or when this process is not the
    distributed master. Otherwise, depending on the schedule in ``args``, it
    writes any of the following into ``args.save_dir``:

    * ``checkpoint{epoch}.pt`` -- at the end of an epoch whose number is a
      multiple of ``args.save_interval``, unless ``args.no_epoch_checkpoints``;
    * ``checkpoint_{epoch}_{updates}.pt`` -- mid-epoch, whenever the update
      count is a multiple of ``args.save_interval_updates`` (if that is > 0);
    * ``checkpoint_best.pt`` -- when ``val_loss`` is lower than the best value
      recorded so far (or no best has been recorded yet);
    * ``checkpoint_last.pt`` -- on every call.

    The lowest validation loss seen so far is kept in the function attribute
    ``save_checkpoint.best``. When ``args.keep_interval_updates`` > 0, older
    mid-epoch checkpoints beyond that count are deleted after a mid-epoch save.

    Args:
        args: parsed training options (reads ``no_save``, ``save_dir``,
            ``save_interval``, ``save_interval_updates``,
            ``no_epoch_checkpoints`` and ``keep_interval_updates``).
        trainer: object providing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: epoch iterator providing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: validation loss for this point, or ``None`` when validation
            was not run.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return
    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Ordered so the files are written in a fixed sequence, "last" at the end.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    # Must be evaluated before save_checkpoint.best is updated below, so the
    # comparison is against the previous best.
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best") or val_loss < save_checkpoint.best
    )
    checkpoint_conds["checkpoint_last.pt"] = True  # refreshed on every call

    if val_loss is not None:
        prev_best = getattr(save_checkpoint, "best", val_loss)
        save_checkpoint.best = min(val_loss, prev_best)

    # Only record "best" once one exists. On a first call with val_loss=None
    # nothing has set save_checkpoint.best yet, and reading it unconditionally
    # raised AttributeError. Omitting the key (rather than storing None) also
    # keeps a None from later reaching the `val_loss < best` comparison if a
    # loader restores it -- presumably loaders check `"best" in extra_state`;
    # confirm against the checkpoint-loading code.
    extra_state = {}
    if hasattr(save_checkpoint, "best"):
        extra_state["best"] = save_checkpoint.best
    extra_state["train_iterator"] = epoch_itr.state_dict()
    extra_state["val_loss"] = val_loss

    # checkpoint_last.pt is always selected, so this list is never empty.
    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    for cp in checkpoints:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # Remove old mid-epoch checkpoints. utils.checkpoint_paths is expected
        # to return them sorted in descending order, so every entry past the
        # first keep_interval_updates is an older one.
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[args.keep_interval_updates :]:
            try:
                os.remove(old_chk)
            except FileNotFoundError:
                # Already gone (e.g. removed externally); nothing to prune.
                pass