def save()

in mmf/utils/checkpoint.py

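Serializes the trainer state (model, optimizer, early-stopping stats, fp16 grad scaler, LR scheduler, and the resolved config) and writes it to a numbered model_<update>.ckpt file and to the "current" checkpoint; when update_best is set, the "best" checkpoint is refreshed as well.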

    def save(self, update, iteration=None, update_best=False):
        # Only save in the main process.
        # For XLA we use the xm.save method, which ensures that the actual
        # checkpoint saving happens only on the master node and takes care
        # of all the necessary synchronization.
        if not is_main() and not is_xla():
            return

        logger.info("Checkpoint save operation started!")
        if not iteration:
            iteration = update

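        # Build the paths for the numbered, best, and current checkpoint files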
        ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
        best_ckpt_filepath = os.path.join(
            self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
        )
        current_ckpt_filepath = os.path.join(
            self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
        )

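        # Pull the best-so-far stats from the early-stopping callback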
        best_iteration = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
        )
        best_update = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_update
        )
        best_metric = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_value
        )

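        # Gather the model and, if fp16 is enabled, the grad scaler state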
        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get("distributed")
        fp16_scaler = getattr(self.trainer, "scaler", None)
        fp16_scaler_dict = None

        if fp16_scaler is not None:
            fp16_scaler_dict = fp16_scaler.state_dict()

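        # Unwrap (Distributed)DataParallel so the bare module's weights are saved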
        if data_parallel is True:
            model = model.module

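        # Assemble everything needed to fully restore a training run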
        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "best_iteration": best_iteration,
            "current_iteration": iteration,
            "current_epoch": self.trainer.current_epoch,
            "num_updates": update,
            "best_update": best_update,
            "best_metric_value": best_metric,
            "fp16_scaler": fp16_scaler_dict,
            # Convert to container to avoid any dependencies
            "config": OmegaConf.to_container(self.config, resolve=True),
        }

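        # Persist the LR scheduler state when a scheduler is attached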
        lr_scheduler = self.trainer.lr_scheduler_callback

        if (
            lr_scheduler is not None
            and getattr(lr_scheduler, "_scheduler", None) is not None
        ):
            lr_scheduler = lr_scheduler._scheduler
            ckpt["lr_scheduler"] = lr_scheduler.state_dict()

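        # Record git metadata when running from a repository checkout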
        if self.git_repo:
            git_metadata_dict = self._get_vcs_fields()
            ckpt.update(git_metadata_dict)

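        # Write the per-update numbered checkpoint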
        with open_if_main(ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        if update_best:
            logger.info("Saving best checkpoint")
            with open_if_main(best_ckpt_filepath, "wb") as f:
                self.save_func(ckpt, f)

        # Always save the current checkpoint
        logger.info("Saving current checkpoint")
        with open_if_main(current_ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        # Save the current checkpoint as W&B artifacts for model versioning.
        if self.config.training.wandb.log_checkpoint:
            logger.info(
                "Saving current checkpoint as W&B Artifacts for model versioning"
            )
            self.trainer.logistics_callback.wandb_logger.log_model_checkpoint(
                current_ckpt_filepath
            )

        # Remove old checkpoints if max_to_keep is set
        # In XLA, only delete checkpoint files in main process
        if self.max_to_keep > 0 and is_main():
            if len(self.saved_iterations) == self.max_to_keep:
                self.remove(self.saved_iterations.pop(0))
            self.saved_iterations.append(update)

        logger.info("Checkpoint save operation finished!")