ultravox/training/train.py
def train(config: config_base.TrainConfig):
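    # WORLD_SIZE / LOCAL_RANK are injected by the distributed launcher (e.g. torchrun);
    # a plain single-process run falls back to world_size=1 and rank 0 (the master).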
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_master = local_rank == 0
    is_distributed = world_size > 1
    # DDP makes every rank emit the same log lines, so restrict verbose logging to the master process.
    logging.basicConfig(level=logging.INFO if is_master else logging.ERROR)
    # os.environ["TORCH_LOGS"] = "ERROR" if is_master else "WARNING"
    transformers.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)
    hf_datasets.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)
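
    # NCCL is the standard backend for CUDA collectives; init_process_group blocks
    # until every rank has joined the group.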
    if is_distributed:
        torch.distributed.init_process_group(backend="nccl")
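
    # run_on_master_first: the master executes the block first (downloading or checking
    # the cache) while the other ranks wait at a barrier, then the rest run it and hit
    # the warm cache.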
    with ddp_utils.run_on_master_first(is_master):
        # For larger models, we assume that the weights are already downloaded via prefetch_weights.py,
        # otherwise the barrier call can time out.
        # This call is only here as a backstop in case prefetch_weights.py was not run, for example in a local/test run.
        prefetch_weights.download_weights(
            [config.text_model, config.audio_model], config.model_load_dir
        )
logging.info("Instantiating model and processor...")
model_load_context = (
accelerate.init_empty_weights()
if config.use_fsdp and not is_master
else contextlib.nullcontext()
)
# If we're using FSDP, we can just initialize the model on the main process
# and use sync_model_states to distribute the weights to the other processes.
# Otherwise we'd be loading the model on every process, which uses too much CPU memory.
with model_load_context:
model_pack = model_types.create_model_pack(config)
model = model_pack.model
logging.info("Model and processor instantiated.")
# Starting W&B. HF Trainer can also do this, but this way we can include the config.
# Initializing sooner also means more of the stdout logs are captured by W&B.
if "wandb" in config.report_logs_to and is_master:
wandb.init(
project=os.getenv("WANDB_PROJECT", "ultravox"),
config=dataclasses.asdict(config),
name=config.exp_name,
dir="runs",
tags=config.run_tags,
save_code=True,
)

    if config.model_load_dir:
        logging.info(f"Loading model state dict from {config.model_load_dir}")
        load_path = file_utils.download_dir_if_needed(config.model_load_dir)
        if os.path.isdir(load_path):
            load_path = os.path.join(load_path, "model*.safetensors")
        paths = glob.glob(load_path)
        assert len(paths) > 0, f"No model files found at {load_path}"
        for path in paths:
            state_dict = safetensors.torch.load_file(path)
            mismatch = model.load_state_dict(state_dict, strict=False)
            if mismatch.unexpected_keys:
                raise ValueError(
                    f"Unexpected keys in state dict: {mismatch.unexpected_keys}"
                )
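
    # Note: the strict=False load above tolerates missing keys (each safetensors shard
    # contains only part of the weights), while unexpected keys are a hard error.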

    if config.ignore_data_skip and config.resume_from_load_dir:
        new_shuffle_seed = random.randint(1000, 1999)
        logging.info(
            "Since data skipping is ignored when resuming from a checkpoint,"
            f" randomly setting the train dataset seed to {new_shuffle_seed}."
        )
        config.train_dataset_args.shuffle_seed = new_shuffle_seed
        if wandb.run:
            wandb.run.config.update(
                {"train_dataset_args": dataclasses.asdict(config.train_dataset_args)},
                allow_val_change=True,
            )

    model.print_trainable_parameters()

    if not config.use_fsdp:
        # Moving to device in FSDP is handled by the Trainer
        model.to(device=torch.device(config.device, index=local_rank))

    logging.info(f"Using device (world_size): {model.device} ({world_size})")

    # Register custom datasets
    datasets.register_datasets(config.get_data_sets())

    # Prepare datasets, subsetting if needed
    train_dataset: datasets.SizedIterableDataset
    val_datasets: Dict[str, datasets.SizedIterableDataset] = {}

    train_dataset = prepare_dataset(
        train_args=config,
        model_pack=model_pack,
        data_opts=config.get_train_sets(),
        data_args=config.train_dataset_args,
        verbose=is_master,
    )
    if is_master:
        for val_opt in config.get_val_sets():
            val_dataset = prepare_dataset(
                train_args=config,
                model_pack=model_pack,
                data_opts=[val_opt],
                data_args=config.val_dataset_args,
                verbose=is_master,
            )
            val_datasets[val_opt.name] = val_dataset
        logging.info(
            f"Loaded {len(config.train_sets)} data sets, sample limit:"
            f" {config.train_dataset_args.max_samples}"
            f" (val sample limit: {config.val_dataset_args.max_samples})"
        )
    else:
        # When using DDP with split_batches=True, the primary process distributes the batches
        # to the workers. The point of this is to avoid unnecessary data
        # processing/downloading in the worker processes.
        # When training by epochs, the EmptyDataset must report a length equal to that of
        # the training set.
        train_dataset = datasets.EmptyDataset(len(train_dataset))
        for val_opts in config.get_val_sets():
            val_datasets[val_opts.name] = datasets.EmptyDataset(
                config.val_dataset_args.max_samples or 1
            )
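        # Note: prepare_dataset already ran on this rank above, so len(train_dataset)
        # still reflects the real dataset size when the placeholder is created.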
logging.info(f"Config Params: {config}")
trainer = transformers.Seq2SeqTrainer(
model,
train_dataset=train_dataset,
eval_dataset=val_datasets,
data_collator=model_pack.data_collator,
processing_class=model_pack.processor,
args=transformers.Seq2SeqTrainingArguments(
dataloader_num_workers=config.num_workers if is_master else 0,
output_dir=config.output_dir,
run_name=config.exp_name,
optim=config.optimizer,
num_train_epochs=config.num_epochs,
max_steps=config.max_steps,
eval_strategy="steps" if config.val_steps else "no",
eval_steps=config.val_steps,
save_strategy="steps" if config.save_steps else "no",
save_steps=config.save_steps,
logging_first_step=True,
logging_dir=config.logs_dir,
logging_steps=config.logging_steps,
# TODO (Farzad): reconsider for multi-node
# In DDP world_size is set to num_gpus and we want process-0 to split the batches
per_device_train_batch_size=config.batch_size * world_size,
accelerator_config={"split_batches": True},
gradient_accumulation_steps=config.grad_accum_steps,
eval_accumulation_steps=config.val_accum_steps,
# tf32=dtype == torch.float32 and device.type == "cuda", # TODO: check for Ampere GPU not just CUDA
ddp_find_unused_parameters=False,
learning_rate=config.lr,
lr_scheduler_type=config.lr_scheduler,
lr_scheduler_kwargs=config.lr_scheduler_kwargs,
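            # config.lr_warmup_steps < 1 is interpreted as a fraction of total training
            # steps (warmup_ratio); values >= 1 are an absolute step count (warmup_steps).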
            warmup_steps=0 if config.lr_warmup_steps < 1 else config.lr_warmup_steps,
            warmup_ratio=config.lr_warmup_steps if config.lr_warmup_steps < 1 else 0,
            weight_decay=config.weight_decay,
            # fp16=dtype == torch.float16,
            # bf16=dtype == torch.bfloat16,
            use_cpu=config.device == "cpu",
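            # Offset the seed per rank so per-process randomness (e.g. dropout) differs.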
            seed=config.seed + local_rank,
            report_to=config.report_logs_to,
            # torch_compile=True,
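            # Full-parameter sharding with transformer-block auto-wrapping; backward_pre
            # prefetches the next set of parameters during the backward pass to overlap
            # communication with computation.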
fsdp="full_shard auto_wrap" if config.use_fsdp else "",
fsdp_config={
"backward_prefetch": "backward_pre",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
),
)

    caught_exception = None
    if config.do_train:
        # Training loop
        logging.info("Starting training...")
        t_start = datetime.datetime.now()
        logging.info(f"train start time: {t_start}")
        if config.val_steps:
            if config.use_fsdp:
                logging.warning(
                    "FSDP is enabled: Skipping initial validation since model is not initialized."
                )
            else:
                trainer.evaluate()

        try:
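            # Note: resume_from_load_dir assumes model_load_dir was set above, since
            # load_path is only defined in that branch.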
            resume_from_checkpoint = load_path if config.resume_from_load_dir else None
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        except Exception as e:
            logging.error(f"[rank: {local_rank}] Training failed with error: {e}")
            logging.error(f"[rank: {local_rank}] {traceback.format_exc()}")
            caught_exception = e
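
        # The exception is recorded rather than re-raised immediately so that W&B and
        # the process group can still be torn down cleanly below; it is re-raised at
        # the end of this function.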
        t_end = datetime.datetime.now()
        logging.info(f"train end time: {t_end}")
        logging.info(f"elapsed: {t_end - t_start}")

        # save_final_model(trainer, model_pack, config)

    # Use fixie-ai/evals for evaluation when in use_fsdp mode.
    if config.do_eval:
        if config.model_type == "lsm":
            logging.warning("Evaluation is not supported for LSM models, skipping")
        elif config.use_fsdp:
            logging.warning("Evaluation is not supported in FSDP mode, skipping")
        else:
logging.info("Starting evaluation...")
t_start = datetime.datetime.now()
logging.info(f"eval start time: {t_start}")
# Merge LoRA weights for better inference performance.
# Note: this is irreversible and changes model saving format
model.merge_and_unload()
# changing padding side to left for inference
model_pack.change_text_padding_side("left")
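            # (decoder-only generation requires left padding so prompts are right-aligned
            # and newly generated tokens extend every sequence from the same position)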

            inference = infer.LocalInference(
                model=model,
                processor=model_pack.processor,
                tokenizer=model_pack.get_text_tokenizer(),
                device=(
                    f"{config.device}:{local_rank}" if world_size > 1 else config.device
                ),
                dtype=device_helpers.get_dtype(config.data_type),
            )
            metrics, output_files = eval.eval_datasets(
                inference,
                config.get_eval_sets(),
                config.eval_dataset_args,
                config.eval_batch_size,
                config.eval_max_tokens,
                config.eval_temperature,
                config.output_dir,
            )
            if is_master:
                eval.print_results(metrics, output_files)

            t_end = datetime.datetime.now()
            logging.info(f"eval end time: {t_end}")
            logging.info(f"elapsed: {t_end - t_start}")

    # finish wandb run if it exists
    if wandb.run and is_master:
        wandb.run.finish(exit_code=1 if caught_exception else 0)

    # destroy process group if distributed training
    if world_size > 1:
        torch.distributed.destroy_process_group()

    if caught_exception:
        logging.error(
            f"[rank: {local_rank}] Training failed earlier, exiting and raising error."
        )
        raise caught_exception
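
# Usage sketch (illustrative, not from the original file): train() expects the
# environment variables that a launcher such as torchrun sets, e.g.
#
#   torchrun --nproc_per_node=<num_gpus> <entrypoint>.py
#
# where the entrypoint builds a TrainConfig and calls train(config). For a local
# single-process run, the function can be called directly:
#
#   train(config_base.TrainConfig(...))  # exact constructor arguments depend on the repo's config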