def train()

in ultravox/training/train.py


def train(config: config_base.TrainConfig):
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_master = local_rank == 0
    is_distributed = world_size > 1

    # DDP duplicates log output across processes, so restrict verbose logging to the master process and silence the rest.
    logging.basicConfig(level=logging.INFO if is_master else logging.ERROR)
    # os.environ["TORCH_LOGS"] = "ERROR" if is_master else "WARNING"
    transformers.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)
    hf_datasets.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR)

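    # NCCL is the standard backend for GPU-to-GPU communication on NVIDIA hardware;
    # every rank joins the same default process group here.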
    if is_distributed:
        torch.distributed.init_process_group(backend="nccl")

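    # run_on_master_first: the master rank executes the block first; the other ranks
    # wait at a barrier and run it afterwards, by which point the weights are cached.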
    with ddp_utils.run_on_master_first(is_master):
        # For larger models, we assume that the weights are already downloaded via prefetch_weights.py
        # Otherwise the barrier call can time out.
        # This call is only here as a backstop in case prefetch_weights.py was not run, for example in a local/test run.
        prefetch_weights.download_weights(
            [config.text_model, config.audio_model], config.model_load_dir
        )

    logging.info("Instantiating model and processor...")

    model_load_context = (
        accelerate.init_empty_weights()
        if config.use_fsdp and not is_master
        else contextlib.nullcontext()
    )
    # If we're using FSDP, we can just initialize the model on the main process
    # and use sync_model_states to distribute the weights to the other processes.
    # Otherwise we'd be loading the model on every process, which uses too much CPU memory.
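    # Note: init_empty_weights allocates the parameters on the "meta" device, so the
    # non-master ranks hold no real weight tensors until the states are synced.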
    with model_load_context:
        model_pack = model_types.create_model_pack(config)
        model = model_pack.model

    logging.info("Model and processor instantiated.")

    # Starting W&B. HF Trainer can also do this, but this way we can include the config.
    # Initializing sooner also means more of the stdout logs are captured by W&B.
    if "wandb" in config.report_logs_to and is_master:
        wandb.init(
            project=os.getenv("WANDB_PROJECT", "ultravox"),
            config=dataclasses.asdict(config),
            name=config.exp_name,
            dir="runs",
            tags=config.run_tags,
            save_code=True,
        )

    if config.model_load_dir:
        logging.info(f"Loading model state dict from {config.model_load_dir}")
        load_path = file_utils.download_dir_if_needed(config.model_load_dir)
        if os.path.isdir(load_path):
            load_path = os.path.join(load_path, "model*.safetensors")
        paths = glob.glob(load_path)
        assert len(paths) > 0, f"No model files found at {load_path}"
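        # Checkpoints may be sharded across several model*.safetensors files. Each shard
        # is loaded with strict=False so keys living in other shards are not flagged as
        # missing, while truly unexpected keys still raise below.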
        for path in paths:
            state_dict = safetensors.torch.load_file(path)
            mismatch = model.load_state_dict(state_dict, strict=False)
            if mismatch.unexpected_keys:
                raise ValueError(
                    f"Unexpected keys in state dict: {mismatch.unexpected_keys}"
                )

    if config.ignore_data_skip and config.resume_from_load_dir:
        new_shuffle_seed = random.randint(1000, 1999)
        logging.info(
            "Since data skipping is ignored when resuming from a checkpoint,"
            f" randomly setting the train dataset seed to {new_shuffle_seed}."
        )
        config.train_dataset_args.shuffle_seed = new_shuffle_seed
        if wandb.run:
            wandb.run.config.update(
                {"train_dataset_args": dataclasses.asdict(config.train_dataset_args)},
                allow_val_change=True,
            )

    model.print_trainable_parameters()

    if not config.use_fsdp:
        # Moving to device in FSDP is handled by the Trainer
        model.to(device=torch.device(config.device, index=local_rank))
        logging.info(f"Using device (world_size): {model.device} ({world_size})")

    # Register custom datasets
    datasets.register_datasets(config.get_data_sets())

    # Prepare dataset, subsetting if needed
    train_dataset: datasets.SizedIterableDataset
    val_datasets: Dict[str, datasets.SizedIterableDataset] = {}

    train_dataset = prepare_dataset(
        train_args=config,
        model_pack=model_pack,
        data_opts=config.get_train_sets(),
        data_args=config.train_dataset_args,
        verbose=is_master,
    )
    if is_master:
        for val_opt in config.get_val_sets():
            val_dataset = prepare_dataset(
                train_args=config,
                model_pack=model_pack,
                data_opts=[val_opt],
                data_args=config.val_dataset_args,
                verbose=is_master,
            )
            val_datasets[val_opt.name] = val_dataset
        logging.info(
            f"Loaded {len(config.train_sets)} data sets, sample limit: {config.train_dataset_args.max_samples} (val sample limit: {config.val_dataset_args.max_samples})"
        )
    else:
        # When using DDP with split_batches=True, the primary process distributes the batches to the workers.
        # The point of this is to avoid unnecessary data processing/downloading in the workers.
        # When training by epochs, the EmptyDataset must have the same length as the training set.
        train_dataset = datasets.EmptyDataset(len(train_dataset))
        for val_opt in config.get_val_sets():
            val_datasets[val_opt.name] = datasets.EmptyDataset(
                config.val_dataset_args.max_samples or 1
            )

    logging.info(f"Config Params: {config}")
    trainer = transformers.Seq2SeqTrainer(
        model,
        train_dataset=train_dataset,
        eval_dataset=val_datasets,
        data_collator=model_pack.data_collator,
        processing_class=model_pack.processor,
        args=transformers.Seq2SeqTrainingArguments(
            dataloader_num_workers=config.num_workers if is_master else 0,
            output_dir=config.output_dir,
            run_name=config.exp_name,
            optim=config.optimizer,
            num_train_epochs=config.num_epochs,
            max_steps=config.max_steps,
            eval_strategy="steps" if config.val_steps else "no",
            eval_steps=config.val_steps,
            save_strategy="steps" if config.save_steps else "no",
            save_steps=config.save_steps,
            logging_first_step=True,
            logging_dir=config.logs_dir,
            logging_steps=config.logging_steps,
            # TODO (Farzad): reconsider for multi-node
            # In DDP world_size is set to num_gpus and we want process-0 to split the batches
            per_device_train_batch_size=config.batch_size * world_size,
            accelerator_config={"split_batches": True},
            gradient_accumulation_steps=config.grad_accum_steps,
            eval_accumulation_steps=config.val_accum_steps,
            # tf32=dtype == torch.float32 and device.type == "cuda",  # TODO: check for Ampere GPU not just CUDA
            ddp_find_unused_parameters=False,
            learning_rate=config.lr,
            lr_scheduler_type=config.lr_scheduler,
            lr_scheduler_kwargs=config.lr_scheduler_kwargs,
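            # lr_warmup_steps < 1 is interpreted as a fraction of total steps (warmup_ratio);
            # values >= 1 are treated as an absolute number of warmup steps.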
            warmup_steps=0 if config.lr_warmup_steps < 1 else config.lr_warmup_steps,
            warmup_ratio=config.lr_warmup_steps if config.lr_warmup_steps < 1 else 0,
            weight_decay=config.weight_decay,
            # fp16=dtype == torch.float16,
            # bf16=dtype == torch.bfloat16,
            use_cpu=config.device == "cpu",
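            # offset the seed by rank so each process draws a distinct random stream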
            seed=config.seed + local_rank,
            report_to=config.report_logs_to,
            # torch_compile=True,
            fsdp="full_shard auto_wrap" if config.use_fsdp else "",
            fsdp_config={
                "backward_prefetch": "backward_pre",
                "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            },
        ),
    )

    caught_exception = None
    if config.do_train:
        # Training loop
        logging.info("Starting training...")
        t_start = datetime.datetime.now()
        logging.info(f"train start time: {t_start}")

        if config.val_steps:
            if config.use_fsdp:
                logging.warning(
                    "FSDP is enabled: Skipping initial validation since model is not initialized."
                )
            else:
                trainer.evaluate()

        try:
            resume_from_checkpoint = load_path if config.resume_from_load_dir else None
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        except Exception as e:
            logging.error(f"[rank: {local_rank}] Training failed with error: {e}")
            logging.error(f"[rank: {local_rank}] {traceback.format_exc()}")
            caught_exception = e

        t_end = datetime.datetime.now()
        logging.info(f"train end time: {t_end}")
        logging.info(f"elapsed: {t_end - t_start}")

    # save_final_model(trainer, model_pack, config)

    # In FSDP mode, in-process evaluation is not supported; use fixie-ai/evals instead.
    if config.do_eval:
        if config.model_type == "lsm":
            logging.warning("Evaluation is not supported for LSM models, skipping")
        elif config.use_fsdp:
            logging.warning("Evaluation is not supported in FSDP mode, skipping")
        else:
            logging.info("Starting evaluation...")
            t_start = datetime.datetime.now()
            logging.info(f"eval start time: {t_start}")

            # Merge LoRA weights for better inference performance.
            # Note: this is irreversible and changes the model's saving format.
            model.merge_and_unload()
            # changing padding side to left for inference
            model_pack.change_text_padding_side("left")
            inference = infer.LocalInference(
                model=model,
                processor=model_pack.processor,
                tokenizer=model_pack.get_text_tokenizer(),
                device=(
                    f"{config.device}:{local_rank}" if world_size > 1 else config.device
                ),
                dtype=device_helpers.get_dtype(config.data_type),
            )

            metrics, output_files = eval.eval_datasets(
                inference,
                config.get_eval_sets(),
                config.eval_dataset_args,
                config.eval_batch_size,
                config.eval_max_tokens,
                config.eval_temperature,
                config.output_dir,
            )
            if is_master:
                eval.print_results(metrics, output_files)

            t_end = datetime.datetime.now()
            logging.info(f"eval end time: {t_end}")
            logging.info(f"elapsed: {t_end - t_start}")

    # finish wandb run if it exists
    if wandb.run and is_master:
        wandb.run.finish(exit_code=1 if caught_exception else 0)
    # destroy process group if distributed training
    if world_size > 1:
        torch.distributed.destroy_process_group()

    if caught_exception:
        logging.error(
            f"[rank: {local_rank}] Training failed earlier, exiting and raising error."
        )
        raise caught_exception
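
# Minimal invocation sketch (hypothetical, not part of train.py): it assumes that
# config_base.TrainConfig can be constructed directly in code, whereas the real
# entrypoint more likely builds it from CLI arguments or a YAML config. Shown only
# to illustrate that train() is driven by a single TrainConfig instance.
if __name__ == "__main__":
    example_config = config_base.TrainConfig()  # hypothetical: field values omitted
    train(example_config)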