train.py
def train(train_cfg, vlm_cfg):
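# Build the dataloaders, tokenizer, and a descriptive run name from the configs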
train_loader, val_loader = get_dataloaders(train_cfg, vlm_cfg)
tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)
run_name = get_run_name(train_cfg, vlm_cfg)
total_dataset_size = len(train_loader.dataset)
if train_cfg.data_cutoff_idx is None:
run_name = run_name.replace("full_ds", f"{total_dataset_size}samples")
if train_cfg.log_wandb and is_master():
run = wandb.init(
entity=train_cfg.wandb_entity,
project="nanoVLM",
config={
"VLMConfig": asdict(vlm_cfg),
"TrainConfig": asdict(train_cfg)
},
name=run_name,
)
# Initialize model
if train_cfg.resume_from_vlm_checkpoint:
model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path)
else:
model = VisionLanguageModel(vlm_cfg, load_backbone=vlm_cfg.vlm_load_backbone_weights)
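# Report model size and the data/batch setup (master rank only)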
if is_master():
print(f"nanoVLM initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Training summary{' (global)' if is_dist() else ''}: {len(train_loader.dataset)} samples, {int(len(train_loader)*get_world_size())} batches/epoch, batch size {int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}{', training on ' + str(get_world_size()) + ' GPUs' if is_dist() else ''}")
if is_dist():
print(f"Training summary per GPU: {len(train_loader)} batches/epoch, batch size {train_loader.batch_size}")
print(f"Validation summary{' (global)' if is_dist() else ''}: {len(val_loader.dataset)} samples, {int(len(val_loader)*get_world_size())} batches/epoch, batch size {int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}{', training on ' + str(get_world_size()) + ' GPUs' if is_dist() else ''}")
if is_dist():
print(f"Validation summary per GPU: {len(val_loader)} batches/epoch, batch size {val_loader.batch_size}")
# Define optimizer groups
# The vision and language backbones are pretrained, while the modality projection (MP) layer is newly initialized, so training them all with the same learning rate makes little sense
# You could fully freeze the backbones and train only the MP layer, but fine-tuning them with a lower learning rate makes training as a whole easier
param_groups = [{'params': list(model.MP.parameters()), 'lr': train_cfg.lr_mp},
{'params': list(model.decoder.parameters()) + list(model.vision_encoder.parameters()), 'lr': train_cfg.lr_backbones}]
optimizer = optim.AdamW(param_groups)
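# Flat list of every trainable parameter, used below for gradient clipping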
all_params = [p for group in optimizer.param_groups for p in group['params']]
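# Pick the best available device: CUDA, then Apple MPS, then CPU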
device = (
torch.device("cuda") if torch.cuda.is_available()
else torch.device("mps") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
else torch.device("cpu")
)
if device.type == "mps":
torch.backends.mps.enable_fallback_to_cpu = True
torch.mps.empty_cache()
print(f"Using device: {device}")
model.to(device)
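# Optionally compile the model, then wrap it for DDP when running distributed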
if train_cfg.compile:
model = torch.compile(model)
if is_dist():
model = wrap_model(model)
epoch_times = []
best_accuracy = 0
best_val_loss = float('inf')
global_step = 0
epoch = 0
# Training stats accumulators
accumulated_stats = {
'tokens_per_second': [],
'data_load_time': [],
'fw_bw_time': [],
'post_process_time': [],
'images_per_sample': [],
}
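# Outer epoch loop; training stops once max_training_steps optimizer updates have been taken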
while global_step < train_cfg.max_training_steps:
epoch += 1
epoch_start_time = time.time()
model.train()
total_train_loss = 0
total_tokens_processed = 0
optimizer.zero_grad()
data_load_start = time.time()
for i, batch in enumerate(synchronized_dataloader_step(train_loader, is_dist())):
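# An update step closes a gradient-accumulation cycle or is the epoch's last batch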
is_update_step = (i + 1) % train_cfg.gradient_accumulation_steps == 0 or i + 1 == len(train_loader)
batch_start_time = time.time()
images = batch["images"]
input_ids = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
attention_mask = batch["attention_mask"].to(device)
data_load_time = time.time() - data_load_start
# When using DDP with gradient accumulation,
# skip gradient synchronization on intermediate steps to save time.
# Gradients only need to be synced at the end of each accumulation cycle.
if (is_dist()
and train_cfg.gradient_accumulation_steps > 1
and not is_update_step):
context = model.no_sync()
else:
context = contextlib.nullcontext()
fw_bw_start = time.time()
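# Mixed-precision autocast: bfloat16 on CUDA/CPU, float16 otherwise (e.g. MPS)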
autocast_context = torch.autocast(
device_type=device.type,
dtype=torch.bfloat16 if device.type in ['cuda', 'cpu'] else torch.float16
)
with autocast_context:
with context:
_, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)
if train_cfg.gradient_accumulation_steps > 1:
loss = loss / train_cfg.gradient_accumulation_steps
loss.backward()
fw_bw_time = time.time() - fw_bw_start
post_process_start = time.time()
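# End of an accumulation cycle: clip gradients, apply the lr schedule, and take an optimizer step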
if is_update_step:
if train_cfg.max_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(all_params, max_norm=train_cfg.max_grad_norm)
adj_lr_mp = get_lr(global_step, train_cfg.lr_mp, train_cfg.max_training_steps)
adj_lr_backbones = get_lr(global_step, train_cfg.lr_backbones, train_cfg.max_training_steps)
optimizer.param_groups[0]['lr'] = adj_lr_mp
optimizer.param_groups[1]['lr'] = adj_lr_backbones
optimizer.step()
optimizer.zero_grad()
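# Undo the gradient-accumulation scaling so the logged batch loss stays comparable across settings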
batch_loss = loss.item()
if train_cfg.gradient_accumulation_steps > 1:
batch_loss = batch_loss * train_cfg.gradient_accumulation_steps
total_train_loss += batch_loss
num_tokens = torch.sum(attention_mask).item() # Sum of attention mask gives number of tokens
total_tokens_processed += num_tokens
post_process_time = time.time() - post_process_start
images_per_sample = [len(image_pack) for image_pack in images]
batch_end_time = time.time()
batch_duration = batch_end_time - batch_start_time
tokens_per_second = get_world_size() * num_tokens / batch_duration # Multiply by world size to get global tokens/s
# Accumulate training stats
accumulated_stats['tokens_per_second'].append(tokens_per_second)
accumulated_stats['data_load_time'].append(data_load_time)
accumulated_stats['fw_bw_time'].append(fw_bw_time)
accumulated_stats['post_process_time'].append(post_process_time)
accumulated_stats['images_per_sample'].extend(images_per_sample)
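# Periodic in-epoch validation, run only on update steps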
if train_cfg.eval_in_epochs and global_step % train_cfg.eval_interval == 0 and is_update_step:
model.eval()
if device.type == "cuda":
torch.cuda.empty_cache()
with torch.no_grad():
total_val_loss = 0
for batch in val_loader:
images = batch["images"]
input_ids = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
attention_mask = batch["attention_mask"].to(device)
with autocast_context:
_, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)
total_val_loss += loss.item()
avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
avg_val_loss = mean(dist_gather(avg_val_loss)) if is_dist() else avg_val_loss
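# Checkpoint whenever the gathered validation loss improves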
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
if is_master():
save_model = model.module if is_dist() else model # unwrap the model for saving if DDP
save_model.save_pretrained(save_directory=os.path.join(vlm_cfg.vlm_checkpoint_path, run_name))
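# Optional lmms-eval benchmarks; numeric metrics are collected for logging alongside val_loss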
lmms_results = {}
if train_cfg.use_lmms_eval:
from evaluation import cli_evaluate
eval_args = argparse.Namespace(
model=model.module if is_dist() else model,
tasks=train_cfg.lmms_eval_tasks,
limit=train_cfg.lmms_eval_limit,
batch_size=train_cfg.lmms_eval_batch_size,
process_with_media=True,
device=device,
)
# Evaluate using the CLI wrapper
eval_results = cli_evaluate(eval_args)
if is_master() and eval_results and "results" in eval_results[0]:
for task_name, task_results in eval_results[0]["results"].items():
for metric_name, metric_value in task_results.items():
if isinstance(metric_value, (int, float)):
lmms_results[f"{task_name}_{metric_name.split(',')[0]}"] = metric_value
if is_master():
print(f"Step: {global_step}, Val Loss: {avg_val_loss:.4f}, Tokens/s: {tokens_per_second:.2f}")
if train_cfg.log_wandb:
run.log({"val_loss": avg_val_loss, **{f"lmms_eval/{key}": value for key, value in lmms_results.items()}}, step=global_step)
model.train()
# Log training stats every N steps (ALL RANKS must participate in collective ops)
if global_step % train_cfg.stats_log_interval == 0 and len(accumulated_stats['tokens_per_second']) > 0 and is_update_step:
# ALL RANKS: Perform collective operations for training stats
stats = {}
for key in ['tokens_per_second', 'data_load_time', 'fw_bw_time', 'post_process_time', 'images_per_sample']:
if is_dist():
all_values = dist_gather(accumulated_stats[key])
all_values_flat = [item for sublist in all_values for item in sublist] # Flatten list of lists
stats[f'avg_{key}'] = mean(all_values_flat)
else:
stats[f'avg_{key}'] = mean(accumulated_stats[key])
for key in ['data_load_time', 'fw_bw_time', 'post_process_time', 'images_per_sample']:
if is_dist():
all_values = dist_gather(accumulated_stats[key])
all_values_flat = [item for sublist in all_values for item in sublist]
stats[f'max_{key}'] = max(all_values_flat)
else:
stats[f'max_{key}'] = max(accumulated_stats[key])
if is_dist():
all_images_values = dist_gather(accumulated_stats['images_per_sample'])
all_images_flat = [item for sublist in all_images_values for item in sublist]
stats['min_images_per_sample'] = min(all_images_flat)
else:
stats['min_images_per_sample'] = min(accumulated_stats['images_per_sample'])
# MASTER ONLY: Log to wandb
if train_cfg.log_wandb and is_master():
run.log({
**{f"training_stats/{key}": value for key, value in stats.items()},
}, step=global_step)
# ALL RANKS: Reset accumulators
for key in accumulated_stats:
accumulated_stats[key] = []
# Log batch loss
if is_update_step:
# ALL RANKS: gather loss from all ranks if DDP
if is_dist():
batch_loss_gathered = mean(dist_gather(batch_loss))
else:
batch_loss_gathered = batch_loss
# MASTER ONLY: Log to wandb
if train_cfg.log_wandb and is_master():
run.log({
"batch_loss": batch_loss_gathered,
**({"grad_norm": grad_norm} if train_cfg.max_grad_norm is not None else {})
}, step=global_step)
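# global_step counts optimizer updates, not micro-batches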
if is_update_step:
global_step += 1
if global_step >= train_cfg.max_training_steps:
break
data_load_start = time.time()
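# End-of-epoch metrics: average train loss, epoch duration, and global token throughput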
avg_train_loss = total_train_loss / len(train_loader)
# gather average batch loss from all ranks if DDP
avg_train_loss = mean(dist_gather(avg_train_loss)) if is_dist() else avg_train_loss
epoch_end_time = time.time()
epoch_duration = epoch_end_time - epoch_start_time
epoch_times.append(epoch_duration)
# gather and sum total_tokens_processed across all ranks if DDP
total_tokens_processed = sum(dist_gather(total_tokens_processed)) if is_dist() else total_tokens_processed
epoch_tokens_per_second = total_tokens_processed / epoch_duration
if is_master():
if train_cfg.log_wandb:
run.log({"epoch_loss": avg_train_loss,
"epoch_duration": epoch_duration,
"epoch_tokens_per_second": epoch_tokens_per_second})
print(f"Epoch: {epoch}, Step: {global_step}/{train_cfg.max_training_steps}, Train Loss: {avg_train_loss:.4f} | Time: {epoch_duration:.2f}s | T/s: {epoch_tokens_per_second:.2f}")
# Summary Statistics
if is_master():
avg_epoch_time = sum(epoch_times) / len(epoch_times)
total_training_time = sum(epoch_times)
batch_size = int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)
total_samples_processed = batch_size * global_step
avg_time_per_sample = total_training_time / total_samples_processed
print(f"Average time per epoch: {avg_epoch_time:.2f}s")
print(f"Average time per sample: {avg_time_per_sample:.4f}s")
# Push the best model to the hub (Please set your user name in the config!)
if vlm_cfg.hf_repo_name is not None:
print("Training complete. Pushing model to Hugging Face Hub...")
hf_model = VisionLanguageModel.from_pretrained(os.path.join(vlm_cfg.vlm_checkpoint_path, run_name))
hf_model.push_to_hub(vlm_cfg.hf_repo_name)
if train_cfg.log_wandb:
run.summary["avg_epoch_time"] = avg_epoch_time
run.summary["avg_time_per_sample"] = avg_time_per_sample
run.summary["mmstar_acc"] = best_accuracy
run.finish()