# train.py
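# The imports below are a best-effort reconstruction for this excerpt: the standard-library,
# torch, tqdm, transformers, peft, and accelerate imports follow directly from the APIs used
# in train_model(). `get_dataset`, `create_data_loaders`, `forward_with_model`,
# `evaluate_model`, and `logger` are assumed to be project-local helpers defined elsewhere
# in this repository and are not imported here.
import math
import os

import torch
from accelerate import DistributedType
from peft import LoraConfig, get_peft_model
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler
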
def train_model(accelerator, args):
    # Load the dataset selected via the dataset_id argument.
    dataset = get_dataset(
        accelerator=accelerator, dataset_id=args.dataset_id, num_proc=args.num_proc, cache_dir=args.cache_dir
    )
    with accelerator.main_process_first():
        splits = dataset.train_test_split(0.1, seed=2025)
        train_dataset, val_dataset = splits["train"], splits["test"]

    # Shrink the held-out split further: keep 90% of the 10% test split as the validation set.
    with accelerator.main_process_first():
        further_splits = val_dataset.train_test_split(0.1, seed=2025)
        val_dataset = further_splits["train"]
    # Load the model and processor.
    ft_model = AutoModelForCausalLM.from_pretrained(args.model_id, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)

    # Freeze the vision tower if needed.
    if args.freeze_vision_tower:
        for param in ft_model.vision_tower.parameters():
            param.requires_grad = False

    # LoRA config.
    if args.use_lora:
        TARGET_MODULES = ["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"]
        config = LoraConfig(
            r=8,
            lora_alpha=8,
            target_modules=TARGET_MODULES,
            task_type="CAUSAL_LM",
            lora_dropout=0.05,
            bias="none",
            inference_mode=False,
            use_rslora=True,
            init_lora_weights="gaussian",
        )
        ft_model = get_peft_model(ft_model, config)
    # Saving and loading hooks.
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for model in models:
                if isinstance(accelerator.unwrap_model(model), type(accelerator.unwrap_model(ft_model))):
                    model = accelerator.unwrap_model(model)
                    model.save_pretrained(output_dir)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                if weights:
                    weights.pop()

    def load_model_hook(models, input_dir):
        transformer_ = None
        if accelerator.distributed_type != DistributedType.DEEPSPEED:
            while len(models) > 0:
                model = models.pop()
                if isinstance(accelerator.unwrap_model(model), type(accelerator.unwrap_model(ft_model))):
                    transformer_ = model  # noqa: F841
                else:
                    raise ValueError(f"unexpected load model: {accelerator.unwrap_model(model).__class__}")
        else:
            transformer_ = AutoModelForCausalLM.from_pretrained(input_dir)  # noqa: F841

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
    if args.gradient_checkpointing:
        ft_model.gradient_checkpointing_enable()

    # Create DataLoaders.
    train_loader, val_loader = create_data_loaders(
        train_dataset,
        val_dataset,
        args.batch_size,
        args.num_proc,
        processor,
    )

    # Optimizer and scheduler.
    optimizer_cls = torch.optim.AdamW
    if args.use_8bit_adam:
        import bitsandbytes as bnb

        optimizer_cls = bnb.optim.AdamW8bit
    trainable_params = list(filter(lambda p: p.requires_grad, ft_model.parameters()))
    optimizer = optimizer_cls(trainable_params, lr=args.lr)

    # Math around scheduler steps and training steps. Once prepared, the LR scheduler is stepped
    # on every process per optimization step, so the total step count is scaled by num_processes.
    len_train_dataloader_after_sharding = math.ceil(len(train_loader) / accelerator.num_processes)
    num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=math.ceil(num_update_steps_per_epoch * args.epochs) * accelerator.num_processes,
    )

    ft_model, train_loader, val_loader, optimizer, lr_scheduler = accelerator.prepare(
        ft_model, train_loader, val_loader, optimizer, lr_scheduler
    )

    # Recompute the step math with the sharded dataloader returned by `prepare`.
    num_update_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
    max_train_steps = args.epochs * num_update_steps_per_epoch
    args.epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
    # Potentially load in the weights and states from a previous save. Initialize the step
    # counters first so that resuming from a checkpoint is not overwritten afterwards.
    global_step = 0
    first_epoch = 0
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint.
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    # Start training!
    progress_bar = tqdm(
        range(0, max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.float32
    for epoch in range(first_epoch, args.epochs):
        # Training phase.
        ft_model.train()
        for batch in train_loader:
            with accelerator.accumulate(ft_model):
                # Prepare the input and target tensors for each of the four captioning tasks.
                color_inputs, colors = batch["color_inputs"], batch["colors"]
                lighting_inputs, lightings = batch["lighting_inputs"], batch["lightings"]
                lighting_type_inputs, lighting_types = batch["lighting_type_inputs"], batch["lighting_types"]
                composition_inputs, compositions = batch["composition_inputs"], batch["compositions"]

                # Average the per-task losses into a single training objective.
                losses = []
                for inputs, labels in [
                    (color_inputs, colors),
                    (lighting_inputs, lightings),
                    (lighting_type_inputs, lighting_types),
                    (composition_inputs, compositions),
                ]:
                    losses.append(forward_with_model(ft_model, inputs, labels, weight_dtype=weight_dtype).loss)
                loss = torch.stack(losses).mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = ft_model.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                    if global_step % args.save_steps == 0:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                val_loss = None
                if global_step % args.eval_steps == 0:
                    val_loss = evaluate_model(
                        ft_model,
                        val_loader,
                        accelerator.device,
                        global_step,
                        args.max_val_item_count,
                        weight_dtype=weight_dtype,
                        disable_pbar=not accelerator.is_local_main_process,
                    )

                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                if val_loss:
                    logs.update({"val_loss": val_loss})
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

            if global_step >= max_train_steps:
                break

        # Validation phase at the end of each epoch.
        evaluate_model(
            ft_model,
            val_loader,
            accelerator.device,
            global_step,
            args.max_val_item_count,
            weight_dtype=weight_dtype,
            disable_pbar=not accelerator.is_local_main_process,
        )
    # Finish run.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(ft_model).save_pretrained(args.output_dir)
        processor.save_pretrained(args.output_dir)

    accelerator.end_training()
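

# A minimal, hypothetical entry-point sketch showing how train_model() could be driven.
# The argument names mirror the `args.*` fields read above, but the flags, defaults, and the
# Accelerator configuration here are assumptions for illustration, not this repository's actual CLI.
if __name__ == "__main__":
    import argparse

    from accelerate import Accelerator
    from accelerate.utils import ProjectConfiguration

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_id", type=str, required=True)
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="outputs")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--num_proc", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--max_val_item_count", type=int, default=100)
    parser.add_argument("--mixed_precision", type=str, default="no", choices=["no", "fp16"])
    parser.add_argument("--freeze_vision_tower", action="store_true")
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument("--use_8bit_adam", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    cli_args = parser.parse_args()

    accelerator = Accelerator(
        gradient_accumulation_steps=cli_args.gradient_accumulation_steps,
        mixed_precision=cli_args.mixed_precision,
        project_config=ProjectConfiguration(project_dir=cli_args.output_dir),
    )
    train_model(accelerator, cli_args)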