def _should_save_sharded()

in src/hyperpod_nemo_adapter/utils/callbacks/checkpoint.py [0:0]


    def _should_save_sharded(self, trainer: "pl.Trainer", monitor_candidates):
        """
        Make sure we need to save if all criterias are met:
        1. Hit every n steps
        2. Have the value in metric logged
        3. The new score is better.
        """
        save_last_step = trainer.max_steps == trainer.global_step and self._save_last_sharded
        is_sharded_on = self._save_sharded_every_n_steps >= 1 and self._save_top_k >= 1
        is_every_n = is_sharded_on and (trainer.global_step % self._save_sharded_every_n_steps == 0)
        # Neither saving last step nor every n step is needed.
        if not save_last_step and not is_every_n:
            return False
        has_value = self._monitor in monitor_candidates
        if not has_value:
            m = (
                f"`SageMakerModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
                f" metrics: {list(monitor_candidates)}."
                f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
            )
            logging.warn(m)
            return False

        # Check if it hits topk capacity. if hits, check if it is one of the topk.
        is_top_k = len(self._best_k_models) < self._save_top_k
        lowest = -math.inf if self._mode == SageMakerMonitorMode.MAX else math.inf
        if len(self._best_k_models) == self._save_top_k and has_value:
            if len(self._best_k_models):
                lowest = self._best_k_models[-1].score
            else:
                lowest = -math.inf if self._mode == SageMakerMonitorMode.MAX else math.inf
            is_top_k = (
                lowest < monitor_candidates[self._monitor]
                if self._mode == SageMakerMonitorMode.MAX
                else lowest > monitor_candidates[self._monitor]
            )
        if not is_top_k:
            return False

        return True