vision/smolvlm2/smolvlm/train/smolvlm_trainer.py
import os
import torch
import logging
import torch.nn as nn
from typing import Optional, Dict, Any, List
from transformers import Trainer, PreTrainedModel
from transformers.modeling_utils import unwrap_model
from transformers.trainer import get_parameter_names, ALL_LAYERNORM_LAYERS
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

logger = logging.getLogger(__name__)

class SmolVLMTrainer(Trainer):
"""
A specialized Trainer that supports:
- Distinct LR for vision tower vs. connector vs. LLM.
- Save model logic that can handle large models or PEFT.
"""
    def create_optimizer(self):
        if self.optimizer is not None:
            return self.optimizer  # Already created

        # DeepSpeed or SageMaker MP users can rely on the parent's create_optimizer
        # (which then calls this if needed).
        model = self.model
        args = self.args

        # Collect the parameter names that should receive weight decay.
        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
        decay_parameters = [n for n in decay_parameters if "bias" not in n]
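        # `decay_parameters` now lists every parameter name except LayerNorm
        # parameters and biases; only those names will receive weight decay below.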

        # Partition trainable parameters into vision tower / connector / LLM groups.
        vision_params = []
        connector_params = []
        llm_params = []
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            # Decide group by parameter name
            if "vision_model" in n:
                vision_params.append(n)
            elif "connector" in n:
                connector_params.append(n)
            else:
                llm_params.append(n)

        # Build param groups based on the user-defined LRs.
        # If e.g. vision_tower_lr <= 0, the vision tower is not trained:
        # its param group is simply skipped.
        def make_group(param_names, lr_value, name):
            # Returns two subgroups ({decay}, {no_decay}) so that weight decay is
            # only applied to non-bias, non-LayerNorm parameters. The "name" key is
            # only used for logging; torch optimizers carry extra group keys through
            # unchanged.
            if lr_value <= 0:
                return []
            decay = {
                "name": f"{name}_decay",
                "params": [p for n, p in model.named_parameters()
                           if n in param_names and n in decay_parameters],
                "weight_decay": args.weight_decay,
                "lr": lr_value,
            }
            no_decay = {
                "name": f"{name}_no_decay",
                "params": [p for n, p in model.named_parameters()
                           if n in param_names and n not in decay_parameters],
                "weight_decay": 0.0,
                "lr": lr_value,
            }
            return [decay, no_decay]
        groups = []
        groups += make_group(vision_params, args.vision_tower_lr, "vision_tower")
        groups += make_group(connector_params, args.connector_lr, "connector")
        groups += make_group(llm_params, args.language_model_lr, "language_model")
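
        # At this point `groups` holds up to six entries: a decay and a no-decay
        # subgroup for each component whose learning rate is > 0.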

        # Fall back to a single default group if no param groups were created
        # (e.g. all LRs are 0, or nothing requires grad).
        if not groups:
            logger.warning("No param groups found. Possibly all LRs=0 or no requires_grad. "
                           "Falling back to default group.")
            groups = [{"name": "default",
                       "params": [p for p in model.parameters() if p.requires_grad],
                       "weight_decay": args.weight_decay,
                       "lr": args.learning_rate}]

        # Log the details of each parameter group.
        def log_param_groups(groups: List[Dict[str, Any]]):
            logger.info("Parameter Groups Configuration:")
            for group in groups:
                group_name = group.get("name", "unnamed_group")
                num_params = len(group["params"])
                weight_decay = group.get("weight_decay", 0.0)
                lr = group.get("lr", 0.0)
                logger.info(
                    f"  - Group '{group_name}': "
                    f"Number of Params = {num_params}, "
                    f"Weight Decay = {weight_decay}, "
                    f"Learning Rate = {lr}"
                )

        log_param_groups(groups)
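
        # Note: each group's explicit "lr" takes precedence over the default
        # learning rate in the optimizer kwargs resolved below.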
        # Let HF resolve the correct optimizer class and kwargs from the training args.
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
        self.optimizer = optimizer_cls(groups, **optimizer_kwargs)
        return self.optimizer

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False, state_dict=None):
        """
        Saves the model. Supports big models or PEFT seamlessly if not using DeepSpeed.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        # If the user is running DeepSpeed, super().save_model handles the ZeRO
        # partitions (under ZeRO-3 the weights are sharded across ranks).
        if self.is_deepspeed_enabled:
            super().save_model(output_dir, _internal_call=_internal_call)
            return

        # If a state_dict was provided, use it; otherwise gather one from self.model.
        if state_dict is None:
            if hasattr(self.model, "state_dict"):
                state_dict = self.model.state_dict()
            else:
                # Unwrap wrapped/PEFT models to reach the underlying module.
                state_dict = unwrap_model(self.model).state_dict()

        # Let the model handle the actual serialization.
        if self.args.should_save:
            # Typical structure: model.save_pretrained(output_dir, state_dict=state_dict)
            self.model.save_pretrained(output_dir, state_dict=state_dict)
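

# Usage sketch (illustrative, not part of this module): the extra LR fields only
# need to exist on the TrainingArguments instance passed in as `args`; the class
# name below is an assumption.
#
#     from dataclasses import dataclass
#     from transformers import TrainingArguments
#
#     @dataclass
#     class SmolVLMTrainingArguments(TrainingArguments):
#         vision_tower_lr: float = 0.0      # <= 0 means no optimizer group for the vision tower
#         connector_lr: float = 1e-4
#         language_model_lr: float = 1e-5
#
#     trainer = SmolVLMTrainer(
#         model=model,                      # the SmolVLM model to fine-tune
#         args=training_args,               # an instance of SmolVLMTrainingArguments
#         train_dataset=train_dataset,
#         data_collator=collator,
#     )
#     trainer.train()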