in src/hyperpod_nemo_adapter/utils/callbacks/checkpoint.py [0:0]
def _should_save_sharded(self, trainer: "pl.Trainer", monitor_candidates):
"""
Make sure we need to save if all criterias are met:
1. Hit every n steps
2. Have the value in metric logged
3. The new score is better.
"""
save_last_step = trainer.max_steps == trainer.global_step and self._save_last_sharded
is_sharded_on = self._save_sharded_every_n_steps >= 1 and self._save_top_k >= 1
is_every_n = is_sharded_on and (trainer.global_step % self._save_sharded_every_n_steps == 0)
# Neither saving last step nor every n step is needed.
if not save_last_step and not is_every_n:
return False
has_value = self._monitor in monitor_candidates
if not has_value:
m = (
f"`SageMakerModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
f" metrics: {list(monitor_candidates)}."
f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
)
logging.warn(m)
return False
# Check if it hits topk capacity. if hits, check if it is one of the topk.
is_top_k = len(self._best_k_models) < self._save_top_k
lowest = -math.inf if self._mode == SageMakerMonitorMode.MAX else math.inf
if len(self._best_k_models) == self._save_top_k and has_value:
if len(self._best_k_models):
lowest = self._best_k_models[-1].score
else:
lowest = -math.inf if self._mode == SageMakerMonitorMode.MAX else math.inf
is_top_k = (
lowest < monitor_candidates[self._monitor]
if self._mode == SageMakerMonitorMode.MAX
else lowest > monitor_candidates[self._monitor]
)
if not is_top_k:
return False
return True