in mmf/utils/checkpoint.py
def save(self, update, iteration=None, update_best=False):
    # Only save in main process
    if not is_master():
        return

    if not iteration:
        iteration = update

    # Resolve target paths for the per-update, best, and current checkpoints
    ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
    )
    current_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
    )

    # Best-so-far statistics tracked by the early stopping callback
    best_iteration = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
    )
    best_update = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_update
    )
    best_metric = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_value
    )

    model = self.trainer.model
    data_parallel = registry.get("data_parallel") or registry.get("distributed")

    # Unwrap (Distributed)DataParallel so state_dict keys are not
    # prefixed with "module."
    if data_parallel is True:
        model = model.module

    ckpt = {
        "model": model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "current_iteration": iteration,
        "current_epoch": self.trainer.current_epoch,
        "num_updates": update,
        "best_update": best_update,
        "best_metric_value": best_metric,
        # Convert to container to avoid any dependencies
        "config": OmegaConf.to_container(self.config, resolve=True),
    }

    lr_scheduler = self.trainer.lr_scheduler_callback._scheduler

    if lr_scheduler is not None:
        ckpt["lr_scheduler"] = lr_scheduler.state_dict()

    # Record version-control metadata when running from a git repo
    if self.git_repo:
        git_metadata_dict = self._get_vcs_fields()
        ckpt.update(git_metadata_dict)

    with PathManager.open(ckpt_filepath, "wb") as f:
        torch.save(ckpt, f)

    if update_best:
        with PathManager.open(best_ckpt_filepath, "wb") as f:
            torch.save(ckpt, f)

    # Save current always
    with PathManager.open(current_ckpt_filepath, "wb") as f:
        torch.save(ckpt, f)

    # Remove old checkpoints if max_to_keep is set
    if self.max_to_keep > 0:
        if len(self.saved_iterations) == self.max_to_keep:
            self.remove(self.saved_iterations.pop(0))
        self.saved_iterations.append(update)
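
For context, below is a minimal sketch of how a checkpoint written by this method could be restored outside the trainer. The helper name load_mmf_checkpoint is hypothetical and it assumes a plain local path (the trainer itself writes through PathManager, which also supports non-local storage backends); the dictionary keys it reads mirror the ckpt dict assembled above.

import torch


def load_mmf_checkpoint(path, model, optimizer=None, lr_scheduler=None):
    """Hypothetical helper: restore state saved by Checkpoint.save above."""
    # map_location="cpu" avoids requiring the original device layout
    ckpt = torch.load(path, map_location="cpu")

    model.load_state_dict(ckpt["model"])
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    if lr_scheduler is not None and "lr_scheduler" in ckpt:
        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])

    # Bookkeeping fields written by save(); useful when resuming counters
    return {
        "current_iteration": ckpt.get("current_iteration"),
        "num_updates": ckpt.get("num_updates"),
        "current_epoch": ckpt.get("current_epoch"),
        "best_metric_value": ckpt.get("best_metric_value"),
        "config": ckpt.get("config"),
    }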