"""
Nanotron training script.
Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml
```
"""
import argparse
import time
from pprint import pformat
from typing import Dict, Optional, cast
import nanotron.distributed as dist
from nanotron import logging
from nanotron.config import (
DataArgs,
DatasetStageArgs,
NanosetDatasetsArgs,
PretrainDatasetsArgs,
Qwen2Config,
SFTDatasetsArgs,
)
from nanotron.data.dataloader import (
dummy_infinite_data_generator,
get_train_dataloader,
)
from nanotron.data.processing import (
clm_process,
get_datasets,
)
from nanotron.data.sft_processing import prepare_sft_dataset
from nanotron.helpers import (
compute_remain_train_steps_of_a_data_stage_from_ckp,
get_consumed_train_samples_of_a_data_stage_from_ckp,
)
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.sanity_checks import sanity_check_dataloader
from nanotron.trainer import DataStageMetadata, DistributedTrainer
from nanotron.utils import main_rank_first
from torch.utils.data import DataLoader
try:
from huggingface_hub import __version__ as hf_hub_version
from transformers import AutoTokenizer
from transformers import __version__ as tf_version
except ImportError:
hf_hub_version = None
tf_version = None
logger = logging.get_logger(__name__)
# import lovely_tensors as lt
# lt.monkey_patch()
def get_dataloader_from_data_stage(
trainer: DistributedTrainer,
data: DataArgs,
consumed_train_samples_stage: int,
consumed_tokens_per_dataset_folder: Dict[str, int],
last_stages_consumed_tokens_per_dataset_folder: Dict[str, int],
num_remaining_train_steps: int,
sanity_check_dataloader_interval: Optional[int] = None,
):
"""
Returns a dataloader for a given data stage.
data: The data configuration for the current stage.
consumed_train_samples_stage: The number of samples consumed by the model in this stage (each stage starts from zero).
consumed_tokens_per_dataset_folder: The number of tokens consumed per dataset folder in the current stage.
last_stages_consumed_tokens_per_dataset_folder: The number of tokens consumed per dataset folder in previous stages, used to avoid re-seeing those tokens since the sampler restarts for this stage.
num_remaining_train_steps: The number of remaining training steps for this stage.
"""
assert consumed_train_samples_stage >= 0, "consumed_train_samples_stage should be greater than or equal to 0"
assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than or equal to 0"
# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
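# Only the input and output pipeline-parallel ranks need real batch tensors (input ids and labels);
# the intermediate PP ranks only exchange activations, so every dataloader builder below targets these two ranks.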
# Case 1: Dummy data generator
if data.dataset is None:
log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
dataloader = dummy_infinite_data_generator(
micro_batch_size=trainer.micro_batch_size,
sequence_length=trainer.sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
vocab_size=trainer.model_config.vocab_size,
seed=data.seed,
parallel_context=trainer.parallel_context,
use_position_ids=isinstance(
trainer.model_config, Qwen2Config
), # Simulate packed sequences to test SFT or inference
cp_pg=trainer.parallel_context.cp_pg,
)()
# Case 2: HuggingFace datasets
elif isinstance(data.dataset, PretrainDatasetsArgs) or isinstance(data.dataset, SFTDatasetsArgs):
log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
log_rank(
f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
logger=logger,
level=logging.INFO,
rank=0,
)
# We need the 1st device to process the dataset and cache it; the other devices then load from the cache
with main_rank_first(trainer.parallel_context.world_pg):
# TODO @nouamanetazi: this may timeout before 1st device finishes processing dataset. Can we have a ctxmanager to modify timeout?
# TODO: generalise to include validation/test splits
# We load the raw dataset
raw_dataset = get_datasets(
hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets,
hf_dataset_config_name=data.dataset.hf_dataset_config_name,
splits=data.dataset.hf_dataset_splits,
)["train"]
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.padding_side = "left"
sequence_sep_tokens = [tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token, tokenizer.unk_token]
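# These separator tokens mark sequence boundaries in packed samples; they are forwarded to
# get_train_dataloader below, which uses them to derive position ids (see `sequence_sep_tokens` there).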
# assert bos or eos are present
assert (
tokenizer.bos_token is not None or tokenizer.eos_token is not None
), f"Tokenizer must have either bos or eos token, but found none for {tokenizer_path}"
# Check that the tokenizer's vocab size does not exceed the model's vocab size
assert (
tokenizer.vocab_size <= trainer.model_config.vocab_size
), f"Tokenizer's vocab size ({tokenizer.vocab_size}) is larger than the model's vocab size ({trainer.model_config.vocab_size})"
# Different processing for SFT vs pretraining
if isinstance(data.dataset, SFTDatasetsArgs):
# For SFT, use the dedicated prepare_sft_dataset function
# Get optional debug parameter to limit dataset size (for faster development)
debug_max_samples = getattr(data.dataset, "debug_max_samples", None)
# Process the dataset using our dedicated SFT processing module
train_dataset = prepare_sft_dataset(
raw_dataset=raw_dataset,
tokenizer=tokenizer,
trainer_sequence_length=trainer.sequence_length,
debug_max_samples=debug_max_samples,
num_proc=data.dataset.dataset_processing_num_proc_per_process,
)
else:
# For pretraining, use existing CLM processing
train_dataset = clm_process(
raw_dataset=raw_dataset,
tokenizer=tokenizer,
text_column_name=data.dataset.text_column_name,
dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process,
dataset_overwrite_cache=data.dataset.dataset_overwrite_cache,
sequence_length=trainer.sequence_length,
)
# We load the processed dataset on the ranks requiring it
dataloader = get_train_dataloader(
train_dataset=train_dataset,
sequence_length=trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples_stage=consumed_train_samples_stage,
dataloader_num_workers=data.num_loading_workers,
seed_worker=data.seed,
dataloader_drop_last=True,
use_position_ids=isinstance(trainer.model_config, Qwen2Config),
sequence_sep_tokens=sequence_sep_tokens, # Used to generate position ids
)
# Check if we have enough samples for train_steps
total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
num_tokens_needed_for_training = (
num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
)
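# Example: with 1000 remaining steps, global_batch_size=512 and sequence_length=2048,
# this stage needs 1000 * 512 * 2048 ≈ 1.05B tokens in the dataset.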
assert num_tokens_needed_for_training <= total_tokens_dataset, (
f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
)
# Case 3: Nanosets
elif isinstance(data.dataset, NanosetDatasetsArgs):
log_rank("Using TokenizedBytes Dataloader", logger=logger, level=logging.INFO, rank=0)
from nanotron.data.tokenized_bytes import get_tb_dataloader, get_tb_datasets
tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
assert (
len(tokenizer) == trainer.model_config.vocab_size
), f"Tokenizer vocab size ({len(tokenizer)}) does not match model config vocab size ({trainer.model_config.vocab_size}). "
log_rank(
f"[TokenizedBytes] Creating TokenizedBytes with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples",
logger=logger,
level=logging.INFO,
rank=0,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
start_time = time.time()
train_dataset, data_log = get_tb_datasets(
config=data.dataset,
global_batch_size=trainer.global_batch_size,
sequence_length=trainer.sequence_length,
train_steps=trainer.config.tokens.train_steps,
current_iteration=trainer.iteration_step,
parallel_context=trainer.parallel_context,
shuffle=data.dataset.shuffle_files,
eos_token_id=tokenizer.eos_token_id,
seed=data.seed,
consumed_samples=consumed_train_samples_stage,
consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder,
last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder,
)
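# Both per-folder token counters are passed to the dataset so that, on resume or when a new stage
# restarts the sampler, tokens already consumed in each folder are skipped rather than re-seen.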
dataloader = get_tb_dataloader(
dataset=train_dataset,
sequence_length=trainer.sequence_length,
micro_batch_size=trainer.micro_batch_size,
global_batch_size=trainer.global_batch_size,
num_workers=data.num_loading_workers,
cfg=data.dataset,
consumed_samples=consumed_train_samples_stage,
num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, # TODO: this overshoots what's needed by the current stage, but it doesn't matter?
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
dataloader_drop_last=True,
dataloader_pin_memory=True,
use_position_ids=isinstance(trainer.model_config, Qwen2Config),
use_doc_masking=getattr(trainer.model_config, "_use_doc_masking", None),
)
log_rank(
f"[TokenizedBytes] Time taken to create TokenizedBytes: {time.strftime('%M:%S', time.gmtime(time.time() - start_time))} (MM:SS)",
logger=logger,
level=logging.INFO,
rank=0,
)
dist.barrier()
# Create Nanoset
# from nanotron.data.nanoset import Nanoset
# with main_rank_first(trainer.parallel_context.world_pg):
# tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# eos_token_id = tokenizer.eos_token_id
# assert (
# eos_token_id is not None or data.dataset.return_positions is False
# ), "Tokenizer must have an eos token if return_positions is True"
# log_rank(
# f"[Nanoset] Creating Nanoset with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples",
# logger=logger,
# level=logging.INFO,
# rank=0,
# )
# start_time = time.time()
# train_dataset = Nanoset(
# dataset_folders=data.dataset.dataset_folder,
# sequence_length=trainer.sequence_length,
# token_size=data.dataset.token_size_in_bytes,
# train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
# dataset_weights=data.dataset.dataset_weights,
# random_seed=data.seed,
# return_positions=data.dataset.return_positions,
# eos_token_id=eos_token_id,
# )
# end_time = time.time()
# log_rank(
# f"[Nanoset] Time taken to create Nanoset: {time.strftime('%M:%S', time.gmtime(end_time - start_time))} (MM:SS)",
# logger=logger,
# level=logging.INFO,
# rank=0,
# )
# # Prepare dataloader
# train_dataloader = build_nanoset_dataloader(
# train_dataset,
# trainer.sequence_length,
# parallel_context=trainer.parallel_context,
# input_pp_rank=input_pp_rank,
# output_pp_rank=output_pp_rank,
# micro_batch_size=trainer.micro_batch_size,
# consumed_train_samples=consumed_train_samples,
# dataloader_num_workers=data.num_loading_workers,
# dataloader_drop_last=True,
# use_position_ids=isinstance(trainer.model_config, Qwen2Config),
# use_doc_masking=False,
# dataloader_pin_memory=True,
# )
# dist.barrier()
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
if sanity_check_dataloader_interval is not None:
sanity_check_dataloader(
dataloader,
tokenizer_path=trainer.config.tokenizer.tokenizer_name_or_path,
sanity_check_dataloader_interval=sanity_check_dataloader_interval,
)
return dataloader
def get_dataloader(
trainer: DistributedTrainer, sanity_check_dataloader_interval: Optional[int] = None
) -> Dict[str, DataLoader]:
dataloaders = {}
# Print training plan
log_rank("Training plan", logger=logger, level=logging.INFO, rank=0, is_separator=True)
stages_info = "".join(
f"[Stage {stage.name}] start from step {stage.start_training_step} \n" for stage in trainer.config.data_stages
)
full_log_message = f"There are {len(trainer.config.data_stages)} training stages \n{stages_info}"
log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)
current_stage = None
# WARNING: we assume we only train on the last stage
stage_idx = len(trainer.config.data_stages) - 1
stage_args = trainer.config.data_stages[stage_idx]
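# NOTE: only the last stage defined in the config is built here (see WARNING above). A new stage's
# metadata is appended when the current iteration is exactly one step before its `start_training_step`.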
if trainer.iteration_step + 1 == stage_args.start_training_step:
log_rank(f"Starting new stage {stage_args.name}", logger=logger, level=logging.INFO, rank=0)
# we start a new stage
if stage_idx >= len(trainer.metadata.data_stages):
trainer.metadata.data_stages.append(DataStageMetadata(
name=stage_args.name,
start_training_step=stage_args.start_training_step,
consumed_train_samples=0,
consumed_tokens_per_dataset_folder={},
sequence_length=trainer.sequence_length,
))
elif len(trainer.metadata.data_stages) < len(trainer.config.data_stages):
raise ValueError(f"If you're trying to start a new stage, you need to set `start_training_step` to the step after the last stage's: {trainer.iteration_step+1}")
current_stage = trainer.metadata.data_stages[stage_idx]
cur_stage_consumed_train_samples = current_stage.consumed_train_samples
consumed_tokens_per_dataset_folder = current_stage.consumed_tokens_per_dataset_folder
stage_args_data = trainer.config.data_stages[stage_idx].data
num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
current_stage, trainer.config, trainer.metadata
) # TODO: check this
log_rank(
f"Current stage: {current_stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {cur_stage_consumed_train_samples} samples"
f"Consumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}",
logger=logger,
level=logging.INFO,
rank=0,
)
# Warn if the sequence length changed from the previous stage: in that case consumed_train_samples is 0, so we assume we're reading the dataset folders from the beginning (which lets us resume training)
if current_stage.sequence_length != trainer.metadata.data_stages[-1].sequence_length:
raise NotImplementedError("We don't support changing sequence length between stages yet")
if current_stage.consumed_train_samples == 0:
log_rank(
f"Warning: The sequence length of the last stage has changed from {trainer.metadata.data_stages[-1].sequence_length} to {current_stage.sequence_length}. We'll assume we're reading from the beginning of the dataset folders.",
logger=logger,
level=logging.WARNING,
rank=0,
)
else:
# we're resuming training, so that's fine
pass
cur_stage_consumed_train_samples = current_stage.consumed_train_samples
else:
# Prepare last_stages_consumed_tokens_per_dataset_folder, used to offset the BlendableDataset so we don't re-see tokens already consumed in previous stages even though the sampler has restarted for this stage
last_stages_consumed_tokens_per_dataset_folder = {}
for stage in trainer.metadata.data_stages[:-1]:
for folder_path, consumed_tokens in stage.consumed_tokens_per_dataset_folder.items():
last_stages_consumed_tokens_per_dataset_folder[folder_path] = last_stages_consumed_tokens_per_dataset_folder.get(folder_path, 0) + consumed_tokens
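# Example: if a previous stage consumed 1e9 tokens from folder A and 5e8 from folder B, the new
# stage's dataset is offset by those amounts so tokens from A and B are not re-seen.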
dataloaders[current_stage.name] = get_dataloader_from_data_stage(
trainer,
stage_args_data,
consumed_train_samples_stage=cur_stage_consumed_train_samples,
consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder,
last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder,
num_remaining_train_steps=num_remaining_train_steps,
sanity_check_dataloader_interval=sanity_check_dataloader_interval,
)
return dataloaders
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
parser.add_argument(
"--sanity-check-dataloader-interval",
type=int,
default=None,
help="Optional interval to print dataloader samples",
)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
config_file = args.config_file
# Load trainer and data
trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer, args.sanity_check_dataloader_interval)
# Train
trainer.train(dataloader)