in src/nanotron/trainer.py [0:0]
def pre_training(self, *args, **kwargs):
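    """Pre-training setup: warn if sanity checks are enabled, log the run
    configuration (batch sizes, parallelism, token budget), and initialize
    Weights & Biases logging on the responsible ranks."""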
    if not self.config.general.ignore_sanity_checks:
        log_rank(
            "Sanity checks are enabled; this will slow down training. To disable them, set `config.general.ignore_sanity_checks` to `True`",
            logger=logger,
            level=logging.WARNING,
            rank=0,
        )
        assert (
            os.environ.get("NANOTRON_BENCHMARK", "0") != "1"
        ), "Sanity checks are enabled while you're running a benchmark. Make sure to disable them by setting `config.general.ignore_sanity_checks` to `True`"
    metadata: TrainingMetadata = self.metadata
    log_rank("Start training", logger=logger, level=logging.INFO, rank=0, is_separator=True)
    log_rank(
        f"mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | cp: {self.parallel_context.cp_pg.size()} | sequence_length: {self.sequence_length} | global_batch_size: {self.global_batch_size} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_tokens_total: {metadata.consumed_tokens_total}",  # noqa
        logger=logger,
        level=logging.INFO,
        rank=0,
    )
    current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
    # Initialize wandb for each TP group if TP > 1, but only for dp=0 ranks
    if wandb is not None:
        tp_size = self.parallel_context.tp_pg.size()
        dp_cp_rank = dist.get_rank(self.parallel_context.dp_cp_pg)
        tp_rank = dist.get_rank(self.parallel_context.tp_pg)
        world_rank = dist.get_rank(self.parallel_context.world_pg)
        if tp_size > 1 and self.metrics_logging.log_level > 0:
            # Create one wandb logger per TP group for DP=0 ranks
            if dp_cp_rank == 0:
                # Create a run name that includes the TP group
                run_name = f"{current_time}_{self.config.general.run}_tp_group_{tp_rank}"
                wandb.init(
                    project=self.config.general.project,
                    name=run_name,
                    config={"nanotron_config": self.config.as_dict()},
                )
                log_rank(
                    f"Initialized wandb run '{run_name}' for TP rank {tp_rank}",
                    logger=logger,
                    level=logging.INFO,
                    rank=world_rank,
                )
        elif world_rank == self.logger_ranks[0]:
            run_name = f"{current_time}_{self.config.general.run}"
            x_stats_sampling_interval = os.environ.get("STATS_SAMPLING_INTERVAL_IN_SEC", None)
            wandb_settings = {}
            if x_stats_sampling_interval is not None:
                wandb_settings["x_stats_sampling_interval"] = float(x_stats_sampling_interval)
wandb_settings["x_stats_open_metrics_endpoints"] = {
"dcgm": "http://localhost:9104/metrics",
"node": "http://localhost:9100/metrics",
"lustre": "http://localhost:9106/metrics",
"gpu": "http://26.0.168.238:9103/metrics",
"efa": "http://localhost:9101/metrics",
}
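                # Restrict collection to metrics whose names match these prefixes.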
wandb_settings["x_stats_open_metrics_filters"] = [
"DCGM_FI_",
"node_",
"lustre_",
"nvidia_gpu_",
"efa_",
]
            wandb.init(
                project=self.config.general.project,
                name=run_name,
                config={"nanotron_config": self.config.as_dict()},
                settings=wandb.Settings(**wandb_settings),
            )
            # Save the config as YAML and upload it to the wandb run
            fd, temp_config_path = tempfile.mkstemp(suffix=".yaml", prefix="config")
            os.close(fd)
            self.config.save_as_yaml(temp_config_path)
            wandb.save(temp_config_path, base_path=os.path.dirname(temp_config_path), policy="now")
            log_rank(
                f"Initialized wandb run '{run_name}'",
                logger=logger,
                level=logging.INFO,
                rank=world_rank,
            )
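# Illustrative launch (not part of the trainer; paths and flags are the usual nanotron examples):
# opt into the extended wandb system metrics by exporting the sampling interval before training, e.g.
#   STATS_SAMPLING_INTERVAL_IN_SEC=5 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 \
#       run_train.py --config-file examples/config_tiny_llama.yaml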