in src/nanotron/config/config.py
def __post_init__(self):
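    # Let the nested S3 upload section run its own __post_init__ validation first.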
if self.s3_upload is not None:
self.s3_upload.__post_init__()
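    # LightEval defaults to evaluating at every checkpoint; an explicit eval_interval
    # must be a multiple of checkpoint_interval so evals line up with saved checkpoints.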
if self.lighteval is not None:
if self.lighteval.eval_interval is None:
self.lighteval.eval_interval = self.checkpoints.checkpoint_interval
else:
assert (
self.lighteval.eval_interval % self.checkpoints.checkpoint_interval == 0
), f"eval_interval={self.lighteval.eval_interval} must be a multiple of checkpoint_interval={self.checkpoints.checkpoint_interval}"
    # Final sanity checks across separate argument sections:
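    # The profiler schedule covers skip_first steps plus `repeat` cycles of
    # (wait + warmup + active) steps, so training must run at least that long.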
if self.profiler is not None and self.profiler.profiler_export_path is not None:
total_profiling_steps = self.profiler.skip_first + self.profiler.repeat * (
self.profiler.wait + self.profiler.warmup + self.profiler.active
)
assert (
self.tokens.train_steps >= total_profiling_steps
), f"Profiling steps ({total_profiling_steps}) must be less than or equal to train steps ({self.tokens.train_steps})"
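    # Default the LR decay length to whatever remains after warmup.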
if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
self.optimizer.learning_rate_scheduler.lr_decay_steps = (
self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps
)
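    # Data stages: sort by start_training_step, then check that one stage starts at
    # step 1 and that stage names and start steps are unique.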
if self.data_stages is not None:
self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step)
names = [stage.name for stage in self.data_stages]
training_steps = [stage.start_training_step for stage in self.data_stages]
assert any(
stage.start_training_step == 1 for stage in self.data_stages
), "You must have a training stage starting at 1 in the config's data_stages"
for stage in self.data_stages:
if names.count(stage.name) > 1:
                raise ValueError(f"Each stage must have a unique name; got duplicates in {names}")
if training_steps.count(stage.start_training_step) > 1:
                raise ValueError(
                    f"Each stage must have a unique start_training_step; please change it for stage {stage.name}"
                )
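            # Nanoset datasets carry their own vocab_size and tokenizer: propagate them
            # into the config when unset, otherwise make sure they match.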
if isinstance(stage.data.dataset, NanosetDatasetsArgs):
if self.model.model_config.vocab_size == -1:
self.model.model_config.vocab_size = stage.data.dataset.vocab_size
logger.warning(
f"Setting model's vocab_size to {self.model.model_config.vocab_size} from dataset's vocab_size ({stage.data.dataset.vocab_size})"
)
assert (
self.model.model_config.vocab_size == stage.data.dataset.vocab_size
), f"Model's vocab_size ({self.model.model_config.vocab_size}) does not match dataset's ({stage.data.dataset.dataset_folder}) vocab_size ({stage.data.dataset.vocab_size})"
if self.tokenizer is None:
self.tokenizer = TokenizerArgs(tokenizer_name_or_path=stage.data.dataset.tokenizer_name)
logger.warning(
f"Setting tokenizer to {self.tokenizer.tokenizer_name_or_path} from dataset's tokenizer ({stage.data.dataset.tokenizer_name})"
)
assert (
self.tokenizer.tokenizer_name_or_path == stage.data.dataset.tokenizer_name
), f"Tokenizer passed in config ({self.tokenizer.tokenizer_name_or_path}) does not match dataset's ({stage.data.dataset.dataset_folder}) tokenizer ({stage.data.dataset.tokenizer_name})"
# NOTE: must order the stages by start_training_step from lowest to highest
assert all(
self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
for i in range(len(self.data_stages) - 1)
), "The stages are not sorted by start_training_step in increasing order"
# # if lighteval, we need tokenizer to be defined
# if self.checkpoints.lighteval is not None:
# assert self.tokenizer.tokenizer_name_or_path is not None
# Model verifications
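    # Attention heads must shard evenly across tensor parallelism, and (for GQA)
    # the head count must be an integer multiple of the KV head count.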
assert (
self.model.model_config.num_attention_heads % self.parallelism.tp == 0
), f"num_attention_heads ({self.model.model_config.num_attention_heads}) must be divisible by tp ({self.parallelism.tp})"
assert (
self.model.model_config.num_attention_heads >= self.model.model_config.num_key_value_heads
), f"num_attention_heads ({self.model.model_config.num_attention_heads}) must be >= num_key_value_heads ({self.model.model_config.num_key_value_heads})"
assert (
self.model.model_config.num_key_value_heads >= self.parallelism.tp
), f"num_key_value_heads ({self.model.model_config.num_key_value_heads}) must be >= tp ({self.parallelism.tp})" # TODO: remove this once we ensure KV heads get duplicated correctly
assert (
self.model.model_config.num_attention_heads % self.model.model_config.num_key_value_heads == 0
), f"num_attention_heads ({self.model.model_config.num_attention_heads}) must be divisible by num_key_value_heads ({self.model.model_config.num_key_value_heads})"
    # data_stages: default each stage's sequence_length to the global one
if self.data_stages is not None:
for stage in self.data_stages:
if stage.sequence_length is None:
stage.sequence_length = self.tokens.sequence_length