in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py
def _get_max_steps(self):
    """
    Compute the maximum number of training steps (-1 if it cannot be computed).

    Overrides NeMo's _get_max_steps to:
    1. Honor the lr_decay_iters override from the config.
    2. Read the train dataloader length from the datamodule.
    """
    if self._cfg.lr_decay_iters is not None:
        return self._cfg.lr_decay_iters
    if getattr(self, "_trainer", None) is None:
        _logger.warning("Cannot compute `max_steps` as no trainer is set")
        return -1
    if self._trainer.max_steps >= 0:
        # Note that when `trainer.max_steps` is defined, we ignore
        # `max_epochs` (even if training may end before `max_steps` is
        # reached due to `max_epochs`). This is for backward compatibility
        # with older versions of NeMo.
        if self._trainer.max_epochs is not None and self._trainer.max_epochs >= 0:
            _logger.warning(
                "Ignoring `trainer.max_epochs` when computing `max_steps` "
                "because `trainer.max_steps` is already "
                f"set to {self._trainer.max_steps}."
            )
        return self._trainer.max_steps
    if self._trainer.max_epochs is None or self._trainer.max_epochs < 0:
        _logger.warning("Cannot compute `max_steps` if neither `trainer.max_steps` nor `trainer.max_epochs` is set")
        return -1
if getattr(self, "_train_dl", None) is None:
_logger.warning("Cannot compute `max_steps` from the number of epochs as the train dataloader is not set")
return -1
    # The number of training steps per epoch is typically the number of
    # global batches in the training set...
    num_global_batches = len(self.datamodule._train_dl)
    steps_per_epoch = num_global_batches
    # ... unless it is constrained by the `limit_train_batches` option.
    limit_batches = self._trainer.limit_train_batches
    if limit_batches is not None:
        if isinstance(limit_batches, float):
            limit_batches = int(limit_batches * num_global_batches)
        steps_per_epoch = min(num_global_batches, limit_batches)
    return steps_per_epoch * self._trainer.max_epochs
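
The sketch below is not part of the adapter; it condenses the same resolution order (config lr_decay_iters first, then trainer.max_steps, then max_epochs times steps per epoch) into a standalone helper so the priority can be exercised directly. All names in it (_Cfg, _Trainer, _resolve_max_steps) are invented for illustration and assume nothing about the real config or trainer objects beyond the attributes used above.

# Hypothetical, self-contained illustration of the priority order implemented
# by _get_max_steps. Names here are invented; they are not adapter APIs.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class _Cfg:
    lr_decay_iters: Optional[int] = None


@dataclass
class _Trainer:
    max_steps: int = -1
    max_epochs: Optional[int] = None
    limit_train_batches: Optional[Union[int, float]] = None


def _resolve_max_steps(cfg: _Cfg, trainer: _Trainer, num_global_batches: int) -> int:
    # Same decision sequence as the method above, condensed for illustration.
    if cfg.lr_decay_iters is not None:
        return cfg.lr_decay_iters
    if trainer.max_steps >= 0:
        return trainer.max_steps
    if trainer.max_epochs is None or trainer.max_epochs < 0:
        return -1
    steps_per_epoch = num_global_batches
    if trainer.limit_train_batches is not None:
        limit = trainer.limit_train_batches
        if isinstance(limit, float):
            # A float limit is treated as a fraction of the available global batches.
            limit = int(limit * num_global_batches)
        steps_per_epoch = min(num_global_batches, limit)
    return steps_per_epoch * trainer.max_epochs


# The config override wins outright.
assert _resolve_max_steps(_Cfg(lr_decay_iters=1000), _Trainer(max_steps=500), 200) == 1000
# trainer.max_steps wins over max_epochs.
assert _resolve_max_steps(_Cfg(), _Trainer(max_steps=500, max_epochs=3), 200) == 500
# Otherwise derive from epochs: 3 * min(200, int(0.5 * 200)) == 300.
assert _resolve_max_steps(_Cfg(), _Trainer(max_epochs=3, limit_train_batches=0.5), 200) == 300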