in src/hyperpod_nemo_adapter/utils/callbacks/checkpoint.py [0:0]
def __init__(self, cfg, *args, **kw):
    """Initialise checkpointing settings from ``cfg.exp_manager``.

    Reads the checkpoint directory, the optional ``export_full_model``
    section (full checkpoints) and the optional
    ``checkpoint_callback_params`` section (sharded checkpoints), and stores
    the resulting settings on the instance.

    Args:
        cfg: Configuration object exposing an ``exp_manager`` section that
            supports ``in`` membership tests, ``.get(key, default)`` and
            attribute access to its sub-sections.
        *args: Forwarded unchanged to the parent initialiser.
        **kw: Forwarded unchanged to the parent initialiser.

    Raises:
        AssertionError: If the configured ``mode`` is not one of the
            ``SageMakerMonitorMode`` values (only when assertions are
            enabled).
    """
    exp_manager = cfg.exp_manager

    checkpoint_dir = exp_manager.get("checkpoint_dir", None)
    super().__init__(checkpoint_dir, *args, **kw)
    self._resume_from_checkpoint = exp_manager.get("resume_from_checkpoint", None)

    # Full checkpoint: every attribute gets a default up front so that all of
    # them exist even when the ``export_full_model`` section is absent.
    self._save_full_every_n_steps = None
    self._save_last_full = None
    self._final_full_checkpoint_dir = None
    if "export_full_model" in exp_manager:
        export_cfg = exp_manager.export_full_model
        self._save_full_every_n_steps = export_cfg.get("every_n_train_steps", None)
        self._save_last_full = export_cfg.get("save_last", True)
        self._final_full_checkpoint_dir = export_cfg.get("final_export_dir", None)

    # Sharded checkpoint: fall back to an empty mapping so the ``.get``
    # defaults below apply when the section is missing.
    checkpoint_callback_params = {}
    if "checkpoint_callback_params" in exp_manager:
        checkpoint_callback_params = exp_manager.checkpoint_callback_params
    self._save_sharded_every_n_steps = checkpoint_callback_params.get("every_n_train_steps", None)
    self._save_last_sharded = checkpoint_callback_params.get("save_last", True)
    self._save_top_k = checkpoint_callback_params.get("save_top_k", None)
    self._monitor = checkpoint_callback_params.get("monitor", "step")

    mode = checkpoint_callback_params.get("mode", "max")
    # NOTE(review): this check is skipped under ``python -O``; an invalid
    # mode would then silently resolve to MIN below. Kept as an assert so the
    # exception type seen by existing callers does not change.
    assert mode in [
        member.value for member in SageMakerMonitorMode
    ], f"{mode} is not a valid value for {SageMakerMonitorMode.__name__}"
    self._mode = (
        SageMakerMonitorMode.MAX if mode == SageMakerMonitorMode.MAX.value.lower() else SageMakerMonitorMode.MIN
    )

    # Starts empty; this method does not populate it.
    self._best_k_models = []