in src/hyperpod_nemo_adapter/collections/parts/sagemaker_trainer_builder.py [0:0]
def _create_checkpoint_callbacks(self):
    """Build the list of checkpointing callbacks for the trainer.

    Selection, in order:

    * If ``SUPPORT_CHECKPOINT`` is falsy, no callbacks are created.
    * If a PEFT type is configured (``cfg.model.peft.peft_type`` is not
      None), only ``SageMakerCheckpointPeft`` is added, and only when
      ``self.use_generic_checkpoint`` is set. The regular callbacks below
      are skipped because they may fail with PEFT.
    * Otherwise, ``SageMakerModelCheckpointResilience`` is added when
      ``self.use_resilience_checkpoint`` is set. ``SageMakerCheckpoint`` is
      then added when ``self.use_generic_checkpoint`` is set.

    Returns:
        list: The callback instances to register (possibly empty).
    """
    callbacks = []
    if not SUPPORT_CHECKPOINT:
        return callbacks

    exp_manager = self.cfg.exp_manager

    # PEFT checkpointing callback.
    if self.cfg.model.peft.peft_type is not None:
        if self.use_generic_checkpoint:
            callbacks.append(SageMakerCheckpointPeft(self.cfg))
        # With PEFT, skip the regular checkpoint callbacks; they may fail.
        return callbacks

    # Resilience checkpointing callback.
    if self.use_resilience_checkpoint:
        # An explicit resume path disables auto-reload. Compare to the
        # None singleton by identity rather than with ``==``.
        enabled_auto_reload = exp_manager.resume_from_checkpoint is None
        warmup_steps = exp_manager.auto_checkpoint.warmup_steps
        drop_n_warmup_steps = exp_manager.auto_checkpoint.drop_n_warmup_steps
        callbacks.append(
            SageMakerModelCheckpointResilience(
                enable_auto_reload=enabled_auto_reload,
                checkpoint_dir=exp_manager.get("checkpoint_dir", None),
                warmup_steps=warmup_steps,
                drop_n_warmup_steps=drop_n_warmup_steps,
            )
        )

    # Generic checkpointing callback.
    if self.use_generic_checkpoint:
        callbacks.append(SageMakerCheckpoint(self.cfg))
    return callbacks