# vision/smolvlm2/smolvlm/train/train.py
import os
import math
import logging
import pathlib
from dataclasses import dataclass, field
from typing import Optional, List, Tuple
import wandb
import torch
import transformers
from transformers import (
    HfArgumentParser,
    AutoConfig,
    AutoProcessor,
    set_seed,
)
import torch.distributed as dist
# LoRA / PEFT if needed
try:
    from peft import (
        LoraConfig,
        get_peft_model,
        prepare_model_for_kbit_training,
    )
except ImportError:
    LoraConfig = None
    get_peft_model = None
    prepare_model_for_kbit_training = None

# BitsAndBytes if needed
try:
    from transformers import BitsAndBytesConfig
except ImportError:
    BitsAndBytesConfig = None
from smolvlm.train.smolvlm_trainer import SmolVLMTrainer
from smolvlm.train.args import DataArguments, ModelArguments, TrainingArguments
from smolvlm.datasets.builder import make_supervised_data_module
logger = logging.getLogger(__name__)
# Make CUDA device IDs follow PCI bus order (so they match `nvidia-smi`),
# and let the CUDA caching allocator use expandable segments to reduce fragmentation.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
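# Illustrative launch command (a sketch: the exact module path and flag names
# are assumptions based on the argument dataclasses in smolvlm.train.args and
# transformers.TrainingArguments; adjust to the fields actually defined there):
#
#   torchrun --nproc_per_node=8 -m smolvlm.train.train \
#       --model_name_or_path <path-or-hub-id-of-smolvlm-checkpoint> \
#       --output_dir ./checkpoints/smolvlm-sft \
#       --bf16 True --gradient_checkpointing True \
#       --connector_lr 1e-4 --language_model_lr 1e-5 --vision_tower_lr 0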
def trainer_save_model_safe(trainer: SmolVLMTrainer):
    """
    Saves the model: delegates to `trainer.save_model()` when DeepSpeed is enabled
    (which gathers sharded weights itself); otherwise saves a CPU copy of the
    state dict to avoid holding a second copy on GPU.
    """
    if trainer.is_deepspeed_enabled:
        trainer.save_model()
    else:
        state_dict = trainer.model.state_dict()
        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
        del state_dict
        trainer._save(trainer.args.output_dir, state_dict=cpu_state_dict)
def get_nb_trainable_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Returns (trainable_params, total_params) across the entire model.
    """
    trainable_params = 0
    total_params = 0
    for _, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params, total_params
def set_trainable_params(model: torch.nn.Module, training_args: TrainingArguments):
    """
    Freezes all parameters, then selectively unfreezes based on user flags:
      - tune_vision_tower   => unfreeze vision tower
      - tune_mm_connector   => unfreeze connector
      - tune_language_model => unfreeze base language model
    Logs which modules are unfrozen/frozen for clarity.
    """
    for _, param in model.named_parameters():
        param.requires_grad = False

    vis_unfrozen, conn_unfrozen, llm_unfrozen = 0, 0, 0
    vis_total, conn_total, llm_total = 0, 0, 0
    for name, param in model.named_parameters():
        if ("vision_model" in name) or ("vision_tower" in name):
            vis_total += param.numel()
            if training_args.tune_vision_tower:
                param.requires_grad = True
                vis_unfrozen += param.numel()
        elif ("connector" in name) or ("modality_projection" in name) or ("merger" in name):
            conn_total += param.numel()
            if training_args.tune_mm_connector:
                param.requires_grad = True
                conn_unfrozen += param.numel()
        else:
            llm_total += param.numel()
            if training_args.tune_language_model:
                param.requires_grad = True
                llm_unfrozen += param.numel()

    # Summaries
    trainable_params, total_params = get_nb_trainable_parameters(model)
    pct = 100.0 * trainable_params / max(total_params, 1)
    logger.info("=== Freeze/Unfreeze Summary ===")
    logger.info(f"  tune_vision_tower={training_args.tune_vision_tower} "
                f"(unfrozen {vis_unfrozen:,d}/{vis_total:,d} params)")
    logger.info(f"  tune_mm_connector={training_args.tune_mm_connector} "
                f"(unfrozen {conn_unfrozen:,d}/{conn_total:,d} params)")
    logger.info(f"  tune_language_model={training_args.tune_language_model} "
                f"(unfrozen {llm_unfrozen:,d}/{llm_total:,d} params)")
    logger.info(f"  => Overall trainable params: {trainable_params:,d} / {total_params:,d} "
                f"({pct:.2f}%)\n")
def enable_gradient_checkpointing(model: torch.nn.Module, training_args: TrainingArguments):
    """
    Enables gradient checkpointing if specified in TrainingArguments.
    If the model's LLM submodule supports enabling input grads, we do so (like Qwen does).
    Otherwise, we attach a forward hook to the input embedding to require grads.
    """
    logger.info("Enabling gradient checkpointing in the model.")
    model.config.use_cache = False
    model.config.use_reentrant_checkpointing = False
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    if hasattr(model, "model") and hasattr(model.model, "llm"):
        # Some models have a method like `enable_input_require_grads`.
        if hasattr(model.model.llm, "enable_input_require_grads"):
            logger.info("Calling model.model.llm.enable_input_require_grads() for better GC.")
            model.model.llm.enable_input_require_grads()
        else:
            # Fallback approach:
            logger.info("Attaching a forward hook to require grad on embeddings output.")
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            input_embed = model.get_input_embeddings() if hasattr(model, "get_input_embeddings") else None
            if input_embed is not None:
                input_embed.register_forward_hook(make_inputs_require_grad)
def prepare_model(
    model_args: ModelArguments,
    training_args: TrainingArguments,
):
    """
    Loads and configures the SmolVLM model (SmolVLMForConditionalGeneration, or
    SmolLMMForConditionalGeneration for multi-frame clips), applying rope scaling
    if needed, plus an optional bitsandbytes quantization config.
    """
    logger.info("Loading config from %s", model_args.model_name_or_path)
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )
    compute_dtype = (
        torch.float16 if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    # Possibly adjust rope scaling for a longer context
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
        factor = math.ceil(training_args.model_max_length / orig_ctx_len)
        config.rope_scaling = {"type": "linear", "factor": factor}
        logger.info(f"Auto rope scaling => from {orig_ctx_len} to {training_args.model_max_length}. Factor={factor}")
    # For training, disable cache to reduce memory usage
    config.use_cache = False

    bnb_args = {}
    # If using bitsandbytes in 4- or 8-bit
    if BitsAndBytesConfig is not None and training_args.bits in [4, 8]:
        bnb_args["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=(training_args.bits == 4),
            load_in_8bit=(training_args.bits == 8),
            llm_int8_skip_modules=["lm_head"],
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type,
        )
        logger.info(f"Using bitsandbytes quantization: bits={training_args.bits}, type={training_args.quant_type}")

    # Set the attention implementation
    if training_args.disable_flash_attn2:
        config._attn_implementation = "sdpa"
    else:
        config._attn_implementation = "flash_attention_2"

    if model_args.frames_per_clip > 1:
        from smolvlm.model.modeling_smollmm import SmolLMMForConditionalGeneration
        config.frames_per_clip = model_args.frames_per_clip
        model_cls = SmolLMMForConditionalGeneration
        logger.info(f"Using frame embedding averaging of {model_args.frames_per_clip} frames")
    else:
        from smolvlm.model.modeling_smolvlm import SmolVLMForConditionalGeneration
        model_cls = SmolVLMForConditionalGeneration

    model = model_cls.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=compute_dtype,
        config=config,
        **bnb_args,
    )
    return model
def apply_peft(model: torch.nn.Module, training_args: TrainingArguments) -> torch.nn.Module:
    """
    Applies LoRA/PEFT if training_args.peft_enable is True.
    Also calls `prepare_model_for_kbit_training` if bits=4 or 8.
    """
    if (LoraConfig is None) or (get_peft_model is None):
        raise ValueError("PEFT is not installed, but peft_enable=True was set.")
    logger.info("PEFT/LoRA is enabled. Building LoRA config...")

    # If the user hasn't provided specific modules, fall back to a default set
    peft_target_modules = training_args.target_modules
    if not peft_target_modules:
        peft_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
    # For 4bit/8bit, we can do some prep:
    if training_args.bits in [4, 8] and prepare_model_for_kbit_training is not None:
        logger.info("Running `prepare_model_for_kbit_training` for LoRA + 4/8-bit support.")
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
        )

    lora_config = LoraConfig(
        r=training_args.lora_rank,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
        target_modules=peft_target_modules,
        bias=training_args.lora_bias,  # "none" / "all" / "lora_only"
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    logger.info("LoRA applied. Trainable parameters:")
    model.print_trainable_parameters()
    return model
def auto_resume_or_start(training_args: TrainingArguments) -> bool:
    """
    Detect if there's a previous checkpoint to resume from.
    Return True if we found a checkpoint. Otherwise, we start fresh.
    """
    ckpts = list(pathlib.Path(training_args.output_dir).glob("checkpoint-*"))
    return len(ckpts) > 0
def train():
    """
    Main fine-tuning entry point for the SmolVLM model,
    with optional LoRA + bitsandbytes, logging which submodules are frozen/unfrozen.
    """
    logging.basicConfig(level=logging.INFO)
    logger.info("Parsing arguments with HfArgumentParser...")
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if data_args.packed and model_args.apply_diagonal_block_attention:
        from smolvlm.model.varlen_packing import apply_varlen_patch
        apply_varlen_patch()
    elif data_args.packed and not model_args.apply_diagonal_block_attention:
        logger.warning("Sequence packing has been enabled WITHOUT diagonal block attention!")
    elif not data_args.packed and model_args.apply_diagonal_block_attention:
        logger.warning("Diagonal block attention has been enabled WITHOUT sequence packing. Ignoring flag!")
    # Ensure reproducibility
    set_seed(training_args.seed)

    # Initialize wandb only on the main process (global rank 0) if wandb logging is enabled
    if "wandb" in training_args.report_to and training_args.local_rank == 0 and dist.get_rank() == 0:
        os.environ["WANDB_PROJECT"] = "smolvlmvideo"  # Set project name
        wandb.init(
            name=training_args.run_name,
            config=training_args.to_dict(),
        )
    # Ensure other processes will not try to log
    os.environ["WANDB_MODE"] = "offline"

    # Possibly set tune flags automatically based on user-provided LR
    training_args.tune_language_model = training_args.language_model_lr > 1e-9
    training_args.tune_mm_connector = training_args.connector_lr > 1e-9
    training_args.tune_vision_tower = training_args.vision_tower_lr > 1e-9
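    # Illustrative flag combination (values are assumptions, not repo defaults):
    # --language_model_lr 1e-5 --connector_lr 1e-4 --vision_tower_lr 0 trains
    # the language model and connector while keeping the vision tower frozen.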
    # 1) Prepare model + config
    logger.info("Preparing model + config (possibly with bitsandbytes) ...")
    model = prepare_model(model_args, training_args)

    # 2) Freeze/unfreeze based on user flags, with a logged summary
    set_trainable_params(model, training_args)

    # 3) Possibly enable gradient checkpointing
    if training_args.gradient_checkpointing:
        enable_gradient_checkpointing(model, training_args)

    # 4) Possibly apply LoRA/PEFT
    if training_args.peft_enable:
        model = apply_peft(model, training_args)

    # 5) Load processor (tokenizer + image processor, etc.)
logger.info("Loading AutoProcessor from %s", model_args.model_name_or_path)
if model_args.frames_per_clip > 1:
from smolvlm.model.processing_smollmm import SmolLMMProcessor
processor = SmolLMMProcessor.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side=model_args.padding_side,
trust_remote_code=model_args.trust_remote_code,
)
else:
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side=model_args.padding_side,
trust_remote_code=model_args.trust_remote_code,
)
    # 6) Build dataset + collator
    logger.info("Building dataset + collator...")
    data_module = make_supervised_data_module(processor, data_args, training_args, model_args)

    # 7) Initialize custom trainer
    logger.info("Initializing SmolVLMTrainer...")
    trainer = SmolVLMTrainer(
        model=model,
        args=training_args,
        **data_module
    )

    # 8) Possibly auto-resume from checkpoint
    resume_training = auto_resume_or_start(training_args)
    if resume_training:
        logger.info("Resuming from a previous checkpoint in %s ...", training_args.output_dir)
        trainer.train(resume_from_checkpoint=True)
    else:
        logger.info("Starting a fresh training run...")
        trainer.train()

    # 9) Post-training final steps
    logger.info("Training completed. Saving final model...")
    # Re-enable model cache if needed
    model.config.use_cache = True
    # Save trainer state
    trainer.save_state()
    # Save final model (special logic if in DeepSpeed)
    if trainer.is_deepspeed_enabled:
        trainer.save_model()
    else:
        trainer_save_model_safe(trainer)
    logger.info("All done. Exiting successfully.")
if __name__ == "__main__":
    train()