in mmf/utils/checkpoint.py [0:0]
def save(self, update, iteration=None, update_best=False):
    # Only save in the main process.
    # For XLA we use the xm.save method, which ensures that the actual
    # checkpoint save happens only on the master node and takes care of
    # the necessary synchronization.
    if not is_main() and not is_xla():
        return

    logger.info("Checkpoint save operation started!")
    if not iteration:
        iteration = update

    ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
    )
    current_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
    )

    best_iteration = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
    )
    best_update = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_update
    )
    best_metric = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_value
    )
    model = self.trainer.model
    data_parallel = registry.get("data_parallel") or registry.get("distributed")
    fp16_scaler = getattr(self.trainer, "scaler", None)
    fp16_scaler_dict = None

    if fp16_scaler is not None:
        fp16_scaler_dict = fp16_scaler.state_dict()

    if data_parallel is True:
        model = model.module
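
    # Bundle model weights, optimizer state, training counters, scaler state
    # and the resolved config into a single checkpoint dict.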
    ckpt = {
        "model": model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "current_iteration": iteration,
        "current_epoch": self.trainer.current_epoch,
        "num_updates": update,
        "best_update": best_update,
        "best_metric_value": best_metric,
        "fp16_scaler": fp16_scaler_dict,
        # Convert to container to avoid any dependencies
        "config": OmegaConf.to_container(self.config, resolve=True),
    }
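
    # Persist the LR scheduler state only when a scheduler is actually attached.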
    lr_scheduler = self.trainer.lr_scheduler_callback

    if (
        lr_scheduler is not None
        and getattr(lr_scheduler, "_scheduler", None) is not None
    ):
        lr_scheduler = lr_scheduler._scheduler
        ckpt["lr_scheduler"] = lr_scheduler.state_dict()

    if self.git_repo:
        git_metadata_dict = self._get_vcs_fields()
        ckpt.update(git_metadata_dict)
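
    # Write the per-update checkpoint file (model_<num_updates>.ckpt).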
    with open_if_main(ckpt_filepath, "wb") as f:
        self.save_func(ckpt, f)

    if update_best:
        logger.info("Saving best checkpoint")
        with open_if_main(best_ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

    # Save current always
    logger.info("Saving current checkpoint")
    with open_if_main(current_ckpt_filepath, "wb") as f:
        self.save_func(ckpt, f)

    # Save the current checkpoint as W&B artifacts for model versioning.
    if self.config.training.wandb.log_checkpoint:
        logger.info(
            "Saving current checkpoint as W&B Artifacts for model versioning"
        )
        self.trainer.logistics_callback.wandb_logger.log_model_checkpoint(
            current_ckpt_filepath
        )

    # Remove old checkpoints if max_to_keep is set
    # In XLA, only delete checkpoint files in main process
    if self.max_to_keep > 0 and is_main():
        if len(self.saved_iterations) == self.max_to_keep:
            self.remove(self.saved_iterations.pop(0))
        self.saved_iterations.append(update)

    logger.info("Checkpoint save operation finished!")