def train_model()

in train.py [0:0]


def train_model(accelerator, args):
    # Load the dataset based on the dataset_id argument.
    dataset = get_dataset(
        accelerator=accelerator, dataset_id=args.dataset_id, num_proc=args.num_proc, cache_dir=args.cache_dir
    )
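    # Run the splits on the main process first so the other ranks reuse the cached result:
    # 10% of the data is held out, and 90% of that slice becomes the validation set.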
    with accelerator.main_process_first():
        splits = dataset.train_test_split(0.1, seed=2025)
    train_dataset, val_dataset = splits["train"], splits["test"]
    with accelerator.main_process_first():
        further_splits = val_dataset.train_test_split(0.1, seed=2025)
    val_dataset = further_splits["train"]

    # Load the model and processor
    ft_model = AutoModelForCausalLM.from_pretrained(args.model_id, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)

    # Freeze the vision tower if needed
    if args.freeze_vision_tower:
        for param in ft_model.vision_tower.parameters():
            param.requires_grad = False

    # LoRA config.
    if args.use_lora:
        TARGET_MODULES = ["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"]

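        # Rank-8 adapters; use_rslora scales updates by lora_alpha / sqrt(r) rather than lora_alpha / r.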
        config = LoraConfig(
            r=8,
            lora_alpha=8,
            target_modules=TARGET_MODULES,
            task_type="CAUSAL_LM",
            lora_dropout=0.05,
            bias="none",
            inference_mode=False,
            use_rslora=True,
            init_lora_weights="gaussian",
        )
        ft_model = get_peft_model(ft_model, config)

    # Saving and loading hooks.
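    # These let accelerator.save_state()/load_state() serialize the (possibly PEFT-wrapped)
    # model via save_pretrained() instead of dumping a raw state dict.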
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for model in models:
                if isinstance(accelerator.unwrap_model(model), type(accelerator.unwrap_model(ft_model))):
                    model = accelerator.unwrap_model(model)
                    model.save_pretrained(output_dir)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                if weights:
                    weights.pop()

    def load_model_hook(models, input_dir):
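        # Outside DeepSpeed, Accelerate passes in the prepared models to restore; under
        # DeepSpeed the model is re-instantiated from the checkpoint directory instead.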
        transformer_ = None
        if accelerator.distributed_type != DistributedType.DEEPSPEED:
            while len(models) > 0:
                model = models.pop()

                if isinstance(accelerator.unwrap_model(model), type(accelerator.unwrap_model(ft_model))):
                    transformer_ = model  # noqa: F841
                else:
                    raise ValueError(f"unexpected save model: {accelerator.unwrap_model(model).__class__}")

        else:
            transformer_ = AutoModelForCausalLM.from_pretrained(input_dir)  # noqa: F841

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

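    # Gradient checkpointing trades extra compute for lower activation memory.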
    if args.gradient_checkpointing:
        ft_model.gradient_checkpointing_enable()

    # Create DataLoaders
    train_loader, val_loader = create_data_loaders(
        train_dataset,
        val_dataset,
        args.batch_size,
        args.num_proc,
        processor,
    )

    # Optimizer and scheduler
    optimizer_cls = torch.optim.AdamW
    if args.use_8bit_adam:
        import bitsandbytes as bnb

        optimizer_cls = bnb.optim.AdamW8bit

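    # Only optimize parameters that are still trainable (e.g. the LoRA adapters when use_lora is set).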
    trainable_params = list(filter(lambda p: p.requires_grad, ft_model.parameters()))
    optimizer = optimizer_cls(trainable_params, lr=args.lr)

    # Math around scheduler steps and training steps.
    len_train_dataloader_after_sharding = math.ceil(len(train_loader) / accelerator.num_processes)
    num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
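    # The prepared scheduler is stepped once per process for each optimizer step, so the
    # total step budget is scaled by num_processes.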
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=math.ceil(num_update_steps_per_epoch * args.epochs) * accelerator.num_processes,
    )
    ft_model, train_loader, val_loader, optimizer, lr_scheduler = accelerator.prepare(
        ft_model, train_loader, val_loader, optimizer, lr_scheduler
    )

    # Recompute the step math now that the dataloader has been sharded by prepare().
    num_update_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
    max_train_steps = args.epochs * num_update_steps_per_epoch
    args.epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Potentially load in the weights and states from a previous save.
    global_step = 0
    first_epoch = 0
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    # Start training!
    progress_bar = tqdm(
        range(0, max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    # dtype used by the forward helpers when fp16 mixed precision is requested.
    weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.float32

    for epoch in range(first_epoch, args.epochs):
        # Training phase
        ft_model.train()
        for batch in train_loader:
            with accelerator.accumulate(ft_model):
                # Prepare the input and target tensors
                color_inputs, colors = batch["color_inputs"], batch["colors"]
                lighting_inputs, lightings = batch["lighting_inputs"], batch["lightings"]
                lighting_type_inputs, lighting_types = batch["lighting_type_inputs"], batch["lighting_types"]
                composition_inputs, compositions = batch["composition_inputs"], batch["compositions"]

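                # Compute one loss per annotation task (color, lighting, lighting type, composition)
                # and average them for the backward pass.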
                losses = []
                for inputs, labels in [
                    (color_inputs, colors),
                    (lighting_inputs, lightings),
                    (lighting_type_inputs, lighting_types),
                    (composition_inputs, compositions),
                ]:
                    losses.append(forward_with_model(ft_model, inputs, labels, weight_dtype=weight_dtype).loss)

                loss = torch.stack(losses).mean()
                accelerator.backward(loss)

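                # Clip gradients only once they have been synchronized, i.e. at accumulation boundaries.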
                if accelerator.sync_gradients:
                    params_to_clip = ft_model.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                    if global_step % args.save_steps == 0:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

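            # Periodically evaluate on the held-out validation split.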
            val_loss = None
            if global_step % args.eval_steps == 0:
                val_loss = evaluate_model(
                    ft_model,
                    val_loader,
                    accelerator.device,
                    global_step,
                    args.max_val_item_count,
                    weight_dtype=weight_dtype,
                    disable_pbar=not accelerator.is_local_main_process,
                )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            if val_loss is not None:
                logs.update({"val_loss": val_loss})
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        if global_step >= max_train_steps:
            break

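        # Full validation pass at the end of each epoch.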
        evaluate_model(
            ft_model,
            val_loader,
            accelerator.device,
            global_step,
            args.max_val_item_count,
            weight_dtype=weight_dtype,
            disable_pbar=not accelerator.is_local_main_process,
        )

    # Finish run.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(ft_model).save_pretrained(args.output_dir)
        processor.save_pretrained(args.output_dir)
    accelerator.end_training()
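
# A minimal, hypothetical invocation sketch (not part of train.py): it assumes an
# argparse-style `args` namespace exposing the fields referenced above, plus a
# parse_args() helper that builds it; both names are illustrative, not the script's API.
if __name__ == "__main__":
    from accelerate import Accelerator

    args = parse_args()  # hypothetical CLI parser for the arguments used above
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )
    train_model(accelerator, args)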