def pre_training()

in src/nanotron/trainer.py [0:0]


    def pre_training(self, *args, **kwargs):
        """Run one-time setup right before training starts.

        Warns when sanity checks are enabled (and refuses to run them during a
        benchmark), logs the effective batch / parallelism settings together with
        the resume point taken from ``self.metadata``, and, when ``wandb`` is
        available, initializes Weights & Biases run(s) on the selected ranks.

        Args:
            *args: Unused in this method.
            **kwargs: Unused in this method.

        Raises:
            AssertionError: If sanity checks are enabled while the environment
                variable ``NANOTRON_BENCHMARK`` is ``"1"`` (only when Python is
                not run with ``-O``, since this is an ``assert``).
        """
        if not self.config.general.ignore_sanity_checks:
            log_rank(
                "Sanity checks are enabled, this will slow down the training. To disable them, set `config.general.ignore_sanity_checks` to `True`",
                logger=logger,
                level=logging.WARNING,
                rank=0,
            )
            # Benchmarks must not pay the sanity-check overhead.
            assert (
                os.environ.get("NANOTRON_BENCHMARK", "0") != "1"
            ), "Sanity checks are enabled while you're running a benchmark. Make sure to disable them by setting `config.general.ignore_sanity_checks` to `True`"

        metadata: TrainingMetadata = self.metadata

        log_rank("Start training", logger=logger, level=logging.INFO, rank=0, is_separator=True)
        log_rank(
            f"mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | cp: {self.parallel_context.cp_pg.size()} | sequence_length: {self.sequence_length} | global_batch_size: {self.global_batch_size} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_tokens_total: {metadata.consumed_tokens_total}",  # noqa
            logger=logger,
            level=logging.INFO,
            rank=0,
        )

        # `wandb` is presumably bound to None when the package is unavailable — TODO confirm at import site.
        if wandb is None:
            return

        # Timestamp prefix for run names; only needed when wandb is in use.
        current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")

        tp_size = self.parallel_context.tp_pg.size()
        dp_cp_rank = dist.get_rank(self.parallel_context.dp_cp_pg)
        tp_rank = dist.get_rank(self.parallel_context.tp_pg)
        world_rank = dist.get_rank(self.parallel_context.world_pg)

        if tp_size > 1 and self.metrics_logging.log_level > 0:
            # One wandb run per TP rank, created on ranks whose DP/CP rank is 0.
            # NOTE(review): the selection only looks at the dp_cp group; ranks that differ along
            # other process groups (e.g. PP, if present) would also pass and create runs with the
            # same name. This branch also does not save the config file or apply the
            # STATS_SAMPLING_INTERVAL_IN_SEC settings used below — confirm both are intended.
            if dp_cp_rank == 0:
                run_name = f"{current_time}_{self.config.general.run}_tp_group_{tp_rank}"

                wandb.init(
                    project=self.config.general.project,
                    name=run_name,
                    config={"nanotron_config": self.config.as_dict()},
                )
                log_rank(
                    f"Initialized wandb run '{run_name}' for TP rank {tp_rank}",
                    logger=logger,
                    level=logging.INFO,
                    rank=world_rank,
                )
        elif world_rank == self.logger_ranks[0]:
            # Single run, created only on the first logger rank.
            run_name = f"{current_time}_{self.config.general.run}"
            x_stats_sampling_interval = os.environ.get("STATS_SAMPLING_INTERVAL_IN_SEC", None)

            wandb_settings = {}

            # Opt-in system-metrics scraping from OpenMetrics exporters.
            if x_stats_sampling_interval is not None:
                wandb_settings["x_stats_sampling_interval"] = float(x_stats_sampling_interval)
                wandb_settings["x_stats_open_metrics_endpoints"] = {
                    "dcgm": "http://localhost:9104/metrics",
                    "node": "http://localhost:9100/metrics",
                    "lustre": "http://localhost:9106/metrics",
                    # NOTE(review): unlike the other endpoints this one targets a fixed remote
                    # host rather than localhost — confirm this is intended on every cluster.
                    "gpu": "http://26.0.168.238:9103/metrics",
                    "efa": "http://localhost:9101/metrics",
                }
                wandb_settings["x_stats_open_metrics_filters"] = [
                    "DCGM_FI_",
                    "node_",
                    "lustre_",
                    "nvidia_gpu_",
                    "efa_",
                ]

            wandb.init(
                project=self.config.general.project,
                name=run_name,
                config={"nanotron_config": self.config.as_dict()},
                settings=wandb.Settings(**wandb_settings),
            )

            # Save the config file with the run. `mkstemp` creates the file atomically, unlike the
            # deprecated `mktemp`, which only returns a name that another process could claim
            # before we write to it. Only the path is needed, so close the descriptor immediately.
            fd, temp_config_path = tempfile.mkstemp(suffix=".yaml", prefix="config")
            os.close(fd)
            self.config.save_as_yaml(temp_config_path)
            # The file is intentionally left in place: wandb may still read it after `save` returns.
            wandb.save(temp_config_path, base_path=os.path.dirname(temp_config_path), policy="now")
            log_rank(
                f"Initialized wandb run '{run_name}' for TP rank {tp_rank}",
                logger=logger,
                level=logging.INFO,
                rank=world_rank,
            )