in fairseq/checkpoint_utils.py
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return
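
    # collect optimizer state that may be sharded across workers so the
    # saving rank can serialize a complete copy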
    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state
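
    # workers that are not responsible for writing this checkpoint can stop here;
    # state_dict() may still be called on every rank when building it requires
    # collective communication (e.g. fully sharded training)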
    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
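    # map each candidate checkpoint filename to whether it should be written on
    # this call (epoch checkpoint, mid-epoch update checkpoint, best, last, ...)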
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
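    # when only the N best checkpoints are kept, parse the worst retained score
    # from the existing checkpoint.best_* filenames and only add a new one if
    # this val_loss is at least as good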
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)

    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints
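
    # extra state is stored next to the model/optimizer state so a resumed run
    # can restore the data iterator position and the best validation score so far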
    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
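
    # write the first selected checkpoint once, then copy it to the remaining
    # filenames (or skip the copies when asynchronous writing is enabled)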
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        trainer.save_checkpoint(checkpoints[0], extra_state)

        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )
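
    # prune old checkpoints according to the keep_interval_updates /
    # keep_last_epochs / keep_best_checkpoints settings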
    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)
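
For context, a minimal sketch of how this function is typically driven from a training loop. The `cfg`, `trainer`, and `epoch_itr` objects are assumed to be set up the way fairseq's training entry point does it, and `run_validation` is a hypothetical stand-in for the validation step that produces the metric named by cfg.checkpoint.best_checkpoint_metric; the exact call site is illustrative, not a verbatim excerpt.

from fairseq import checkpoint_utils

def validate_and_save(cfg, trainer, task, epoch_itr):
    # hypothetical helper: run validation and return the tracked metric value
    # (e.g. the loss selected by cfg.checkpoint.best_checkpoint_metric)
    val_loss = run_validation(cfg, trainer, task, epoch_itr)

    # save_checkpoint decides which files to write (checkpointN.pt,
    # checkpoint_N_M.pt, checkpoint_best.pt, checkpoint_last.pt, ...)
    # and prunes old ones according to cfg.checkpoint.keep_* options
    checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, val_loss)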