in vision/smolvlm2/smolvlm/train/train.py [0:0]
def train():
    """
    Main fine-tuning entry point for the smolVLM model.

    Pipeline:
      1) Parse CLI args into (ModelArguments, DataArguments, TrainingArguments).
      2) Optionally patch attention for packed sequences.
      3) Build the model; freeze/unfreeze submodules per the learning-rate
         flags; optionally enable gradient checkpointing and LoRA/PEFT.
      4) Load the processor and build the dataset + collator.
      5) Run SmolVLMTrainer, auto-resuming from a checkpoint if one exists.
      6) Save trainer state and the final model (DeepSpeed-aware).

    Side effects: configures logging, may set WANDB_PROJECT/WANDB_MODE env
    vars, writes checkpoints and the final model to training_args.output_dir.
    """
    logging.basicConfig(level=logging.INFO)
    logger.info("Parsing arguments with HfArgumentParser...")
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sequence packing and diagonal block attention only make sense together:
    # packing without the attention patch lets tokens attend across sample
    # boundaries, and the patch without packing has nothing to act on.
    if data_args.packed and model_args.apply_diagonal_block_attention:
        from smolvlm.model.varlen_packing import apply_varlen_patch
        apply_varlen_patch()
    elif data_args.packed and not model_args.apply_diagonal_block_attention:
        # logger.warn is deprecated in favor of logger.warning.
        logger.warning("Sequence packing has been enabled WITHOUT diagonal block attention!")
    elif not data_args.packed and model_args.apply_diagonal_block_attention:
        logger.warning("Diagonal block attention has been enabled WITHOUT sequence packing. Ignoring flag!")

    # Ensure reproducibility across data shuffling and weight init.
    set_seed(training_args.seed)

    # Initialize wandb only on the main process (global rank 0) when wandb
    # logging is enabled. dist.get_rank() raises if the default process group
    # is not initialized (e.g. single-process runs), so guard it.
    if (
        "wandb" in training_args.report_to
        and training_args.local_rank == 0
        and (not dist.is_initialized() or dist.get_rank() == 0)
    ):
        os.environ["WANDB_PROJECT"] = "smolvlmvideo"  # Set project name
        wandb.init(
            name=training_args.run_name,
            config=training_args.to_dict(),
        )
    else:
        # Ensure other processes will not try to log to the wandb server.
        os.environ["WANDB_MODE"] = "offline"

    # Derive tune flags from the user-provided learning rates: a component is
    # trainable iff its LR is (numerically) non-zero.
    training_args.tune_language_model = training_args.language_model_lr > 1e-9
    training_args.tune_mm_connector = training_args.connector_lr > 1e-9
    training_args.tune_vision_tower = training_args.vision_tower_lr > 1e-9

    # 1) Prepare model + config (possibly quantized via bitsandbytes).
    logger.info("Preparing model + config (possibly with bitsandbytes) ...")
    model = prepare_model(model_args, training_args)

    # 2) Freeze/unfreeze submodules based on the tune flags (prints a summary).
    set_trainable_params(model, training_args)

    # 3) Optionally trade compute for memory with gradient checkpointing.
    if training_args.gradient_checkpointing:
        enable_gradient_checkpointing(model, training_args)

    # 4) Optionally wrap the model with LoRA/PEFT adapters.
    if training_args.peft_enable:
        model = apply_peft_if_needed(model, training_args)

    # 5) Load processor (tokenizer + image processor, etc.). Both branches use
    # identical kwargs; only the processor class differs.
    logger.info("Loading AutoProcessor from %s", model_args.model_name_or_path)
    processor_kwargs = dict(
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side=model_args.padding_side,
        trust_remote_code=model_args.trust_remote_code,
    )
    if model_args.frames_per_clip > 1:
        # Multi-frame (video) inputs need the custom smolLMM processor.
        from smolvlm.model.processing_smollmm import SmolLMMProcessor
        processor_cls = SmolLMMProcessor
    else:
        processor_cls = AutoProcessor
    processor = processor_cls.from_pretrained(
        model_args.model_name_or_path,
        **processor_kwargs,
    )

    # 6) Build dataset + collator.
    logger.info("Building dataset + collator...")
    data_module = make_supervised_data_module(processor, data_args, training_args, model_args)

    # 7) Initialize custom trainer.
    logger.info("Initializing SmolVLMTrainer...")
    trainer = SmolVLMTrainer(
        model=model,
        args=training_args,
        **data_module,
    )

    # 8) Auto-resume from the latest checkpoint in output_dir if one exists.
    resume_training = auto_resume_or_start(training_args)
    if resume_training:
        logger.info("Resuming from a previous checkpoint in %s ...", training_args.output_dir)
        trainer.train(resume_from_checkpoint=True)
    else:
        logger.info("Starting a fresh training run...")
        trainer.train()

    # 9) Post-training final steps.
    logger.info("Training completed. Saving final model...")
    # Re-enable the KV cache for inference (disabled during training).
    model.config.use_cache = True
    trainer.save_state()
    # DeepSpeed gathers sharded weights itself; otherwise use the safe saver.
    if trainer.is_deepspeed_enabled:
        trainer.save_model()
    else:
        trainer_save_model_safe(trainer)
    logger.info("All done. Exiting successfully.")