# ultravox/training/train.py
import contextlib
import dataclasses
import datetime
import glob
import logging
import os
import random
import traceback
from typing import Dict, List
import accelerate
import datasets as hf_datasets
import safetensors.torch
import torch
import torch.distributed
import transformers
import wandb
import wandb.sdk
from ultravox import data as datasets
from ultravox.evaluation import eval
from ultravox.inference import infer
from ultravox.model import file_utils
from ultravox.training import config_base
from ultravox.training import ddp_utils
from ultravox.training import model_types
from ultravox.training.helpers import prefetch_weights
from ultravox.utils import device_helpers
from ultravox.utils import monkey_patches
def patch_trainer_save_fsdp_model():
"""
When using FSDP, the trainer._save_checkpoint first calls self.save_model and then accelerator.save_fsdp_model.
    In FULL_STATE_DICT mode this saves the model twice, since save_model only writes the model in that mode and refuses otherwise.
    To make matters worse, the second save produces a huge `pytorch_model_fsdp.bin` file, which is not what we want.
We currently only use FULL_STATE_DICT (default) for training checkpoints.
"""
def save_fsdp_model_if_not_full_state_dict(
fsdp_plugin, accelerator, model, output_dir, **kwargs
):
if "FULL_STATE_DICT" in str(fsdp_plugin.state_dict_type):
return
original_save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, **kwargs)
original_save_fsdp_model = transformers.trainer.save_fsdp_model
transformers.trainer.save_fsdp_model = save_fsdp_model_if_not_full_state_dict
def prepare_dataset(
train_args: config_base.TrainConfig,
model_pack: model_types.ModelPack,
data_opts: List[datasets.DatasetOptions],
data_args: datasets.VoiceDatasetArgs,
verbose: bool = False,
) -> datasets.SizedIterableDataset:
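    """
    Build one dataset from the given dataset options: each named dataset is
    created, the results are interleaved according to their weights, wrapped
    with the model's data processor, and optionally capped at max_samples.
    """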
data_names = [ds.name for ds in data_opts]
data_weights = [ds.weight for ds in data_opts]
data_sets = [
datasets.create_dataset(ds, data_args, verbose=verbose) for ds in data_names
]
# If we're using epochs to train, validate the dataset length is appropriate.
if train_args.max_steps == 0:
for ds in data_sets:
assert (
len(ds) > 1
), f"Dataset {ds} has length {len(ds)} which is too short for epoch training"
interleave = datasets.InterleaveDataset(data_sets, data_weights)
ds_with_proc = model_pack.wrap_with_data_proc(interleave)
if data_args.max_samples:
return datasets.Range(ds_with_proc, data_args.max_samples)
else:
return ds_with_proc
def main() -> None:
monkey_patches.apply_all_patches()
# Disable parallelism to avoid deadlocks in DataLoader, apparently
# multiple processes are forked when using multiple datasets.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Log model checkpoints to W&B; this can be reduced to logging only the final model if storage is an issue
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_PROJECT"] = "ultravox"
config = config_base.get_train_config()
patch_trainer_save_fsdp_model()
transformers.set_seed(config.seed)
train(config)
def train(config: config_base.TrainConfig):
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
is_master = local_rank == 0
is_distributed = world_size > 1
    # DDP blows up logging, so this is an attempt to limit logs to the master process only
logging.basicConfig(level=logging.INFO if is_master else logging.ERROR)
# os.environ["TORCH_LOGS"] = "ERROR" if is_master else "WARNING"
transformers.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)
hf_datasets.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)
if is_distributed:
torch.distributed.init_process_group(backend="nccl")
with ddp_utils.run_on_master_first(is_master):
# For larger models, we assume that the weights are already downloaded via prefetch_weights.py
# Otherwise the barrier call can timeout.
# This call is only here as a backstop in case prefetch_weights.py was not run, for example in a local/test run.
prefetch_weights.download_weights(
[config.text_model, config.audio_model], config.model_load_dir
)
logging.info("Instantiating model and processor...")
model_load_context = (
accelerate.init_empty_weights()
if config.use_fsdp and not is_master
else contextlib.nullcontext()
)
# If we're using FSDP, we can just initialize the model on the main process
# and use sync_model_states to distribute the weights to the other processes.
# Otherwise we'd be loading the model on every process, which uses too much CPU memory.
with model_load_context:
model_pack = model_types.create_model_pack(config)
model = model_pack.model
logging.info("Model and processor instantiated.")
# Starting W&B. HF Trainer can also do this, but this way we can include the config.
# Initializing sooner also means more of the stdout logs are captured by W&B.
if "wandb" in config.report_logs_to and is_master:
wandb.init(
project=os.getenv("WANDB_PROJECT", "ultravox"),
config=dataclasses.asdict(config),
name=config.exp_name,
dir="runs",
tags=config.run_tags,
save_code=True,
)
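    # Optionally warm-start from an existing checkpoint: download the directory
    # if needed and load each (possibly sharded) safetensors file with
    # strict=False, so missing keys are tolerated but unexpected keys raise.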
if config.model_load_dir:
logging.info(f"Loading model state dict from {config.model_load_dir}")
load_path = file_utils.download_dir_if_needed(config.model_load_dir)
if os.path.isdir(load_path):
load_path = os.path.join(load_path, "model*.safetensors")
paths = glob.glob(load_path)
assert len(paths) > 0, f"No model files found at {load_path}"
for path in paths:
state_dict = safetensors.torch.load_file(path)
mismatch = model.load_state_dict(state_dict, strict=False)
if mismatch.unexpected_keys:
raise ValueError(
f"Unexpected keys in state dict: {mismatch.unexpected_keys}"
)
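        # When resuming without data skipping, reshuffle the training data with
        # a fresh seed so the run does not replay the same sample order from step 0.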
if config.ignore_data_skip and config.resume_from_load_dir:
new_shuffle_seed = random.randint(1000, 1999)
logging.info(
"Since data skipping is ignored when resuming from a checkpoint,"
f" randomly setting the train dataset seed to {new_shuffle_seed}."
)
config.train_dataset_args.shuffle_seed = new_shuffle_seed
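            # Keep the W&B run config in sync with the updated dataset args.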
if wandb.run:
wandb.run.config.update(
{"train_dataset_args": dataclasses.asdict(config.train_dataset_args)},
allow_val_change=True,
)
model.print_trainable_parameters()
if not config.use_fsdp:
# Moving to device in FSDP is handled by the Trainer
model.to(device=torch.device(config.device, index=local_rank))
logging.info(f"Using device (world_size): {model.device} ({world_size})")
# Register custom datasets
datasets.register_datasets(config.get_data_sets())
# Prepare dataset, subsetting if needed
train_dataset: datasets.SizedIterableDataset
val_datasets: Dict[str, datasets.SizedIterableDataset] = {}
train_dataset = prepare_dataset(
train_args=config,
model_pack=model_pack,
data_opts=config.get_train_sets(),
data_args=config.train_dataset_args,
verbose=is_master,
)
if is_master:
for val_opt in config.get_val_sets():
val_dataset = prepare_dataset(
train_args=config,
model_pack=model_pack,
data_opts=[val_opt],
data_args=config.val_dataset_args,
verbose=is_master,
)
val_datasets[val_opt.name] = val_dataset
logging.info(
f"Loaded {len(config.train_sets)}) data sets, sample limit: {config.train_dataset_args.max_samples} (val sample limit: {config.val_dataset_args.max_samples})"
)
else:
# When using DDP with split_batches=True, the primary process will distribute the batches to the workers
# The point of this is to avoid unnecessary data processing/downloading in the workers.
        # When training by epochs, the EmptyDataset must report a length equal to the training set's
train_dataset = datasets.EmptyDataset(len(train_dataset))
for val_opts in config.get_val_sets():
val_datasets[val_opts.name] = datasets.EmptyDataset(
config.val_dataset_args.max_samples or 1
)
logging.info(f"Config Params: {config}")
trainer = transformers.Seq2SeqTrainer(
model,
train_dataset=train_dataset,
eval_dataset=val_datasets,
data_collator=model_pack.data_collator,
processing_class=model_pack.processor,
args=transformers.Seq2SeqTrainingArguments(
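            # Only the master rank loads real data; with split_batches=True the
            # batches are fanned out to the other ranks, which need no workers.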
dataloader_num_workers=config.num_workers if is_master else 0,
output_dir=config.output_dir,
run_name=config.exp_name,
optim=config.optimizer,
num_train_epochs=config.num_epochs,
max_steps=config.max_steps,
eval_strategy="steps" if config.val_steps else "no",
eval_steps=config.val_steps,
save_strategy="steps" if config.save_steps else "no",
save_steps=config.save_steps,
logging_first_step=True,
logging_dir=config.logs_dir,
logging_steps=config.logging_steps,
# TODO (Farzad): reconsider for multi-node
# In DDP world_size is set to num_gpus and we want process-0 to split the batches
per_device_train_batch_size=config.batch_size * world_size,
accelerator_config={"split_batches": True},
gradient_accumulation_steps=config.grad_accum_steps,
eval_accumulation_steps=config.val_accum_steps,
# tf32=dtype == torch.float32 and device.type == "cuda", # TODO: check for Ampere GPU not just CUDA
ddp_find_unused_parameters=False,
learning_rate=config.lr,
lr_scheduler_type=config.lr_scheduler,
lr_scheduler_kwargs=config.lr_scheduler_kwargs,
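            # lr_warmup_steps < 1 is interpreted as a warmup ratio of total
            # steps; values >= 1 are interpreted as an absolute step count.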
warmup_steps=0 if config.lr_warmup_steps < 1 else config.lr_warmup_steps,
warmup_ratio=config.lr_warmup_steps if config.lr_warmup_steps < 1 else 0,
weight_decay=config.weight_decay,
# fp16=dtype == torch.float16,
# bf16=dtype == torch.bfloat16,
use_cpu=config.device == "cpu",
seed=config.seed + local_rank,
report_to=config.report_logs_to,
# torch_compile=True,
fsdp="full_shard auto_wrap" if config.use_fsdp else "",
fsdp_config={
"backward_prefetch": "backward_pre",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
),
)
caught_exception = None
if config.do_train:
# Training loop
logging.info("Starting training...")
t_start = datetime.datetime.now()
logging.info(f"train start time: {t_start}")
if config.val_steps:
if config.use_fsdp:
logging.warning(
"FSDP is enabled: Skipping initial validation since model is not initialized."
)
else:
trainer.evaluate()
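        # Catch training failures so W&B and the process group are still torn
        # down below; the exception is re-raised at the end of train().
        # (resume_from_load_dir assumes model_load_dir, and thus load_path, was set above.)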
try:
resume_from_checkpoint = load_path if config.resume_from_load_dir else None
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
except Exception as e:
logging.error(f"[rank: {local_rank}] Training failed with error: {e}")
logging.error(f"[rank: {local_rank}] {traceback.format_exc()}")
caught_exception = e
t_end = datetime.datetime.now()
logging.info(f"train end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")
# save_final_model(trainer, model_pack, config)
# use fixie-ai/evals for evaluation if in use_fsdp mode
if config.do_eval:
if config.model_type == "lsm":
logging.warning("Evaluation is not supported for LSM models, skipping")
if config.use_fsdp:
logging.warning("Evaluation is not supported in FSDP mode, skipping")
else:
logging.info("Starting evaluation...")
t_start = datetime.datetime.now()
logging.info(f"eval start time: {t_start}")
# Merge LoRA weights for better inference performance.
# Note: this is irreversible and changes model saving format
model.merge_and_unload()
# changing padding side to left for inference
model_pack.change_text_padding_side("left")
inference = infer.LocalInference(
model=model,
processor=model_pack.processor,
tokenizer=model_pack.get_text_tokenizer(),
device=(
f"{config.device}:{local_rank}" if world_size > 1 else config.device
),
dtype=device_helpers.get_dtype(config.data_type),
)
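            # Run evaluation over the configured eval sets, collecting metrics
            # and output files for each dataset.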
metrics, output_files = eval.eval_datasets(
inference,
config.get_eval_sets(),
config.eval_dataset_args,
config.eval_batch_size,
config.eval_max_tokens,
config.eval_temperature,
config.output_dir,
)
if is_master:
eval.print_results(metrics, output_files)
t_end = datetime.datetime.now()
logging.info(f"eval end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")
# finish wandb run if it exists
if wandb.run and is_master:
wandb.run.finish(exit_code=1 if caught_exception else 0)
# destroy process group if distributed training
if world_size > 1:
torch.distributed.destroy_process_group()
if caught_exception:
logging.error(
f"[rank: {local_rank}] Training failed earlier, exiting and raising error."
)
raise caught_exception
def save_final_model(
trainer: transformers.Trainer,
model_pack: model_types.ModelPack,
config: config_base.TrainConfig,
):
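    """
    Save the final model weights, forcing FULL_STATE_DICT under FSDP so the
    checkpoint can be loaded outside of FSDP. The call site in train() is
    currently commented out.
    """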
if config.use_fsdp:
# For training checkpoints, even if we decide to use SHARDED_STATE_DICT (which is faster),
# we still want the final save to be with FULL_STATE_DICT so it can be serialized properly.
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
# saves the model weights correctly (FSDP or otherwise)
trainer.save_model(config.output_dir)
if __name__ == "__main__":
main()