in bring-your-own-container/fairseq_translation/fairseq/train_driver.py [0:0]
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write every checkpoint file that is due for the current training state.

    Does nothing when ``args.no_save`` is set or when
    ``distributed_utils.is_master(args)`` is false.

    Files considered, all inside ``args.save_dir``:

    * ``checkpoint{epoch}.pt`` -- at the end of an epoch, unless
      ``args.no_epoch_checkpoints`` is set, every ``args.save_interval`` epochs.
    * ``checkpoint_{epoch}_{updates}.pt`` -- mid-epoch, every
      ``args.save_interval_updates`` updates (disabled when that value is <= 0).
    * ``checkpoint_best.pt`` -- when ``val_loss`` is lower than the best loss
      seen so far (or no best loss is known yet).
    * ``checkpoint_last.pt`` -- on every call.

    The best validation loss is remembered across calls on the function
    attribute ``save_checkpoint.best``. A missing attribute and a value of
    ``None`` are both treated as "no best loss yet".

    Args:
        args: Options object; reads ``no_save``, ``no_epoch_checkpoints``,
            ``save_interval``, ``save_interval_updates``,
            ``keep_interval_updates`` and ``save_dir``.
        trainer: Object providing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: Object providing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: Latest validation loss, or ``None`` if none was computed.

    Returns:
        None.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Read the tracked best defensively: the attribute does not exist until a
    # validation loss has been recorded, and it may hold None.
    prev_best = getattr(save_checkpoint, "best", None)
    is_new_best = val_loss is not None and (
        prev_best is None or val_loss < prev_best
    )

    # Insertion order is the order in which the files are written.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best.pt"] = is_new_best
    checkpoint_conds["checkpoint_last.pt"] = True  # keep this entry last

    # Only move the tracked best when a best checkpoint is actually written,
    # so the stored value always matches checkpoint_best.pt.
    if is_new_best:
        save_checkpoint.best = val_loss

    extra_state = {
        # None when no validation loss has ever been recorded.
        "best": getattr(save_checkpoint, "best", None),
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss,
    }

    for fn, cond in checkpoint_conds.items():
        if cond:
            trainer.save_checkpoint(os.path.join(args.save_dir, fn), extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        old_checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in old_checkpoints[args.keep_interval_updates:]:
            # Skip paths that have already disappeared instead of crashing.
            if os.path.lexists(old_chk):
                os.remove(old_chk)