# vision/smolvlm2/smolvlm/train/smolvlm_trainer.py
def create_optimizer(self):
    if self.optimizer is not None:
        return self.optimizer  # Already created
    # DeepSpeed or SageMaker MP users can rely on the parent's create_optimizer
    # (which then calls this method if needed).
    model = self.model
    args = self.args
    # Collect the parameter names that should receive weight decay
    # (everything except biases and LayerNorm weights).
    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [n for n in decay_parameters if "bias" not in n]
    # Partition trainable parameters into three groups by module name.
    vision_params = []
    connector_params = []
    llm_params = []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Decide which group this parameter belongs to.
        if "vision_model" in n:
            vision_params.append(n)
        elif "connector" in n:
            connector_params.append(n)
        else:
            llm_params.append(n)

    # Build param groups from the user-defined learning rates.
    # A non-positive LR (e.g. vision_tower_lr=0) skips that group entirely,
    # so those parameters are excluded from the optimizer and not trained.
    def make_group(param_names, lr_value, name):
        # Returns two subgroups, one with weight decay and one without, so that
        # weight decay is applied only to non-bias, non-LayerNorm parameters.
        if lr_value <= 0:
            return []
        decay = {
            "name": f"{name}_decay",
            "params": [p for n, p in model.named_parameters()
                       if n in param_names and n in decay_parameters],
            "weight_decay": args.weight_decay,
            "lr": lr_value,
        }
        no_decay = {
            "name": f"{name}_no_decay",
            "params": [p for n, p in model.named_parameters()
                       if n in param_names and n not in decay_parameters],
            "weight_decay": 0.0,
            "lr": lr_value,
        }
        return [decay, no_decay]
    groups = []
    groups += make_group(vision_params, args.vision_tower_lr, "vision_tower")
    groups += make_group(connector_params, args.connector_lr, "connector")
    groups += make_group(llm_params, args.language_model_lr, "language_model")

    # Fallback if no param groups were created (e.g. all LRs are 0 or nothing
    # requires grad): train all trainable params with the default settings.
    if not groups:
        logger.warning("No param groups found. Possibly all LRs=0 or no requires_grad. "
                       "Falling back to default group.")
        groups = [{"name": "default",
                   "params": [p for p in model.parameters() if p.requires_grad],
                   "weight_decay": args.weight_decay,
                   "lr": args.learning_rate}]
    # Function to log details of each parameter group
    def log_param_groups(groups: List[Dict[str, Any]]):
        logger.info("Parameter Groups Configuration:")
        for group in groups:
            group_name = group.get("name", "unnamed_group")
            num_params = len(group["params"])
            weight_decay = group.get("weight_decay", 0.0)
            lr = group.get("lr", 0.0)
            logger.info(
                f" - Group '{group_name}': "
                f"Number of Params = {num_params}, "
                f"Weight Decay = {weight_decay}, "
                f"Learning Rate = {lr}"
            )

    # Log the parameter groups
    log_param_groups(groups)
    # Let HF resolve the correct optimizer class and kwargs from the args.
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    self.optimizer = optimizer_cls(groups, **optimizer_kwargs)
    return self.optimizer
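
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the trainer above): the per-module
# learning rates read from `args` are assumed to live on an extended
# TrainingArguments dataclass. The class name and default values below are
# illustrative only; the field names mirror the attributes used above.
from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class SmolVLMTrainingArguments(TrainingArguments):  # hypothetical name
    vision_tower_lr: float = field(default=0.0)     # 0 => vision tower group is skipped
    connector_lr: float = field(default=1e-4)
    language_model_lr: float = field(default=2e-5)

# Example: args = SmolVLMTrainingArguments(output_dir="out", language_model_lr=2e-5)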