def exp_manager()

in src/hyperpod_nemo_adapter/utils/exp_manager.py [0:0]


def exp_manager(trainer: "pytorch_lightning.Trainer", cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm
    of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir,
    name, and version from the logger. Otherwise it will use the exp_dir and name arguments to create the logging
    directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version can be a datetime string or an integer. Datetime versioning can be disabled by setting
    use_datetime_version to False. It optionally configures TensorBoard, WandB, DLLogger, MLFlow, ClearML and
    Neptune loggers, and records the checkpoint-callback settings on SageMakerAppState.
    It copies sys.argv, and git information if available, to the logging directory. It creates a log file for each
    process to log its output into.

    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from
    the constructed log_dir. When you need to continue the training repeatedly (like on a cluster which you need
    multiple consecutive jobs), you need to avoid creating the version folders. Therefore from v1.0.0, when
    resume_if_exists is set to True, creating the version folders is ignored.

    Args:
        trainer: The lightning trainer. Its rank/device counts and ``fast_dev_run`` flag are read; its callbacks,
            ``_default_root_dir`` and (optionally) ``max_time`` are modified in place.
        cfg: A dict or DictConfig that is validated and merged against ``ExpManagerConfig``.

    Returns:
        The resolved log directory, or None when ``cfg`` is None or the trainer uses ``fast_dev_run``.

    Raises:
        ValueError: If ``cfg`` is neither a dict nor a DictConfig, or if both ``log_local_rank_0_only`` and
            ``log_global_rank_0_only`` are True.
    """

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error("exp_manager did not receive a cfg argument. It will be disabled.")
        return None
    if trainer.fast_dev_run:
        logging.info("Trainer was called with fast_dev_run. exp_manager will return without any functionality.")
        return None

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
    # Resolve interpolations first so the schema merge sees concrete values.
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    # Validate mutually exclusive options before creating directories or mutating app state.
    if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True:
        raise ValueError(
            "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. Please set either one or neither."
        )

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
        resume_if_exists=cfg.resume_if_exists,
    )

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == "":
        checkpoint_name = cfg.name or "default"

    # by default we use the exp_dir as the logger_exp_dir (used by nemo when setting up tensorboard logger)
    logger_exp_dir = exp_dir
    # Optional tensorboard-specific name taken from summary_writer_kwargs; None means "use the experiment name".
    tensorboard_name = None

    # tensorboard kwargs
    if cfg.create_tensorboard_logger:
        # summary_writer_kwargs may be None/empty; only inspect it when it holds entries.
        writer_kwargs = cfg.summary_writer_kwargs
        if writer_kwargs and writer_kwargs.get("save_dir", None):
            # want to make sure there is a save_dir in summary_writer_kwargs
            logger_exp_dir = writer_kwargs.get("save_dir")
            logging.warning(
                "tensorboard logger specified to %s. Overriding exp_dir: %s",
                writer_kwargs.get("save_dir"),
                cfg.exp_dir,
            )
            # https://github.com/NVIDIA/NeMo/blob/58edb40a74f2f6589ed2e9c8c1d5c58fb2eefd63/nemo/utils/exp_manager.py#L1007
            # removing save_dir arg since NeMo passes exp_dir param directly into TensorboardLogger() class so including in kwargs as well would unexpectedly overload class init
            writer_kwargs.pop("save_dir", None)

        if writer_kwargs and writer_kwargs.get("name", None):
            # Same reasoning as save_dir: configure_loggers passes `name` explicitly to the tensorboard logger,
            # so a `name` left in the kwargs would be a duplicate keyword. Lift it out and pass it as the name.
            tensorboard_name = writer_kwargs.pop("name", None)
            logging.info("tensorboard logger name taken from summary_writer_kwargs: %s", tensorboard_name)
        else:
            logging.warning(
                "tensorboard logger specified but no 'name' set. Using experiment name in recipe: %s",
                name,
            )

    # mlflow kwargs
    if cfg.create_mlflow_logger:
        if not cfg.mlflow_logger_kwargs.get("experiment_name", None):
            # want to make sure there is an experiment_name in mlflow_logger_kwargs
            cfg.mlflow_logger_kwargs.experiment_name = cfg.name
            logging.warning(
                "mlflow logger specified but no 'experiment_name' set. Using experiment name in recipe: %s",
                cfg.mlflow_logger_kwargs.experiment_name,
            )

        if not cfg.mlflow_logger_kwargs.get("tracking_uri", None):
            # want to make sure there is a tracking_uri in mlflow_logger_kwargs
            cfg.mlflow_logger_kwargs.tracking_uri = cfg.exp_dir
            logging.warning(
                "mlflow logger specified but no 'tracking_uri' set. Using exp_dir in recipe: %s",
                cfg.mlflow_logger_kwargs.tracking_uri,
            )

    # wandb kwargs
    if cfg.create_wandb_logger:
        if cfg.wandb_logger_kwargs is None:
            # The schema may allow this field to be None; give it an empty mapping so defaults can be filled in.
            cfg.wandb_logger_kwargs = {}

        if not cfg.wandb_logger_kwargs.get("name", None):
            # want to make sure there is a name in wandb_logger_kwargs
            cfg.wandb_logger_kwargs.name = cfg.name
            logging.warning(
                "wandb logger specified but no 'name' set. Using experiment name in recipe: %s",
                cfg.wandb_logger_kwargs.name,
            )

        if not cfg.wandb_logger_kwargs.get("save_dir", None):
            # want to make sure there is a save_dir in wandb_logger_kwargs
            cfg.wandb_logger_kwargs.save_dir = cfg.exp_dir
            logging.warning(
                "wandb logger specified but no 'save_dir' set. Using exp_dir in recipe: %s",
                cfg.wandb_logger_kwargs.save_dir,
            )

    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = SageMakerAppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist
    os.makedirs(log_dir, exist_ok=True)  # Cannot limit creation to global zero as all ranks write to own log file
    logging.info(f"Experiments will be logged at {log_dir}")
    trainer._default_root_dir = log_dir

    # Handle logging to file, overriding log dir and remove Nemo Testing flags
    log_file = log_dir / f"sagemaker_log_globalrank-{global_rank}_localrank-{local_rank}.txt"
    if cfg.log_local_rank_0_only is True:
        if local_rank == 0:
            logging.add_file_handler(log_file)
    elif cfg.log_global_rank_0_only is True:
        if global_rank == 0:
            logging.add_file_handler(log_file)
    else:
        # Logs on all ranks.
        logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks
    # not just global rank 0.
    if (
        cfg.create_tensorboard_logger
        or cfg.create_wandb_logger
        or cfg.create_mlflow_logger
        or cfg.create_dllogger_logger
        or cfg.create_clearml_logger
        or cfg.create_neptune_logger
    ):
        configure_loggers(
            trainer,
            logger_exp_dir,
            log_dir,
            # A tensorboard-specific name (lifted out of summary_writer_kwargs above) takes precedence.
            tensorboard_name or cfg.name,
            cfg.version,
            cfg.checkpoint_callback_params,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
            cfg.create_mlflow_logger,
            cfg.mlflow_logger_kwargs,
            cfg.create_dllogger_logger,
            cfg.dllogger_logger_kwargs,
            cfg.create_clearml_logger,
            cfg.clearml_logger_kwargs,
            cfg.create_neptune_logger,
            cfg.neptune_logger_kwargs,
        )

    # add loggers timing callbacks
    if cfg.log_step_timing:
        timing_callback = TimingCallback(timer_kwargs=cfg.step_timing_kwargs or {})
        trainer.callbacks.insert(0, timing_callback)

    if cfg.ema.enable:
        ema_callback = EMA(
            decay=cfg.ema.decay,
            validate_original_weights=cfg.ema.validate_original_weights,
            cpu_offload=cfg.ema.cpu_offload,
            every_n_steps=cfg.ema.every_n_steps,
        )
        trainer.callbacks.append(ema_callback)

    if cfg.create_early_stopping_callback:
        early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
        trainer.callbacks.append(early_stop_callback)

    if cfg.disable_validation_on_resume:
        # extend training loop to skip initial validation when resuming from checkpoint
        configure_no_restart_validation_training_loop(trainer)
    # Setup a stateless timer for use on clusters.
    if cfg.max_time_per_run is not None:
        found_ptl_timer = False
        for idx, callback in enumerate(trainer.callbacks):
            if isinstance(callback, Timer):
                # NOTE: PTL does not expose a `trainer.max_time`. By the time we are in this function, PTL has already setup a timer if the user specifies `trainer.max_time` so best we can do is replace that.
                # Working: If only `trainer.max_time` is set - it behaves as a normal PTL timer. If only `exp_manager.max_time_per_run` is set - it behaves as a StateLessTimer. If both are set, it also behaves as a StateLessTimer.
                logging.warning(
                    "Found a PTL Timer callback, replacing with a StatelessTimer callback. This will happen if you set trainer.max_time as well as exp_manager.max_time_per_run."
                )
                trainer.callbacks[idx] = StatelessTimer(cfg.max_time_per_run)
                found_ptl_timer = True
                break

        if not found_ptl_timer:
            trainer.max_time = cfg.max_time_per_run
            trainer.callbacks.append(StatelessTimer(cfg.max_time_per_run))

    if is_global_rank_zero():
        # Move files_to_copy to folder and add git information if present
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info
        with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file:
            _file.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / "git-info.log", "w", encoding="utf-8") as _file:
                _file.write(f"commit hash: {git_hash}")
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / "nemo_error_log.txt")

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt", log_dir / "nemo_error_log.txt")

    elif trainer.num_nodes * trainer.num_devices > 1:
        # sleep other ranks so rank 0 can finish
        # doing the initialization such as moving files
        time.sleep(cfg.seconds_to_sleep)

    return log_dir