eval/measure_vram.py
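
# Imports this function relies on. The project-local module paths below are an
# assumption based on the repository layout and may need adjusting.
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets import load_dataset

from data.collators import VQACollator
from data.datasets import VQADataset
from data.processors import get_image_processor, get_tokenizer
from models.vision_language_model import VisionLanguageModel
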
def measure_vram(args, vlm_cfg, train_cfg_defaults):
    if not torch.cuda.is_available():
        print("CUDA not available. VRAM measurement requires a CUDA-enabled GPU.")
        return
    device = torch.device("cuda")
    # --- Model Initialization ---
    torch.cuda.reset_peak_memory_stats(device)
    print(f"Using VLMConfig defaults: load_backbone_weights={vlm_cfg.vlm_load_backbone_weights}")
    model = VisionLanguageModel(vlm_cfg, load_backbone=vlm_cfg.vlm_load_backbone_weights)
    if args.compile:
        print("Compiling the model with torch.compile...")
        model = torch.compile(model)
        print("Model compiled.")
    model.to(device)

    # Measure VRAM after model is loaded to device
    torch.cuda.synchronize()  # Ensure all operations are complete
    initial_vram_allocated_bytes = torch.cuda.memory_allocated(device)
    initial_vram_allocated_mb = initial_vram_allocated_bytes / (1024 ** 2)
    print(f"VRAM allocated after loading model to device: {initial_vram_allocated_mb:.2f} MB")
    print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
    # --- Dataset Preparation ---
    image_processor = get_image_processor(vlm_cfg.vit_img_size)
    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens)

    dataset_path = train_cfg_defaults.train_dataset_path
    # train_cfg_defaults.train_dataset_name is a list; use its first entry if one is given
    dataset_name = train_cfg_defaults.train_dataset_name[0] if train_cfg_defaults.train_dataset_name else None
    batch_sizes_to_test = [int(bs) for bs in args.batch_sizes.split()]
    if not batch_sizes_to_test:
        print("Error: No batch sizes provided or parsed correctly.")
        return

    num_iterations_for_vram = args.num_iterations
    max_bs_to_test = max(batch_sizes_to_test)
    required_samples_for_base_ds = max_bs_to_test * num_iterations_for_vram
    try:
        print(f"Loading dataset: {dataset_path}, name: {dataset_name}")
        # Prefer the 'train' split; fall back to the first available split otherwise
        available_splits = load_dataset(dataset_path, dataset_name).keys()
        split_to_use = 'train' if 'train' in available_splits else list(available_splits)[0]
        base_ds_full = load_dataset(dataset_path, dataset_name, split=split_to_use)
        if len(base_ds_full) < required_samples_for_base_ds:
            print(f"Warning: Dataset '{dataset_name}' (split: {split_to_use}) has {len(base_ds_full)} samples, "
                  f"but {required_samples_for_base_ds} are recommended for max batch size {max_bs_to_test} "
                  f"and {num_iterations_for_vram} iterations. Using all available samples.")
            base_ds_for_vram_test = base_ds_full
        else:
            base_ds_for_vram_test = base_ds_full.select(range(required_samples_for_base_ds))
        print(f"Using {len(base_ds_for_vram_test)} samples for VRAM testing.")
    except Exception as e:
        print(f"Error loading dataset: {dataset_path}, name: {dataset_name}. Error: {e}")
        print("Please ensure the dataset path and name are correct.")
        return
    processed_base_dataset = VQADataset(base_ds_for_vram_test, tokenizer, image_processor)
    vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length, vlm_cfg.mp_image_token_length)

    print("\n--- VRAM Measurement ---")
    results = {}
    for bs in batch_sizes_to_test:
        print(f"\nTesting Batch Size: {bs}")
        if len(processed_base_dataset) < bs:
            print(f"Base processed dataset has {len(processed_base_dataset)} samples, "
                  f"not enough for batch size {bs}. Skipping.")
            results[bs] = "Not enough data"
            continue
        current_loader = DataLoader(
            processed_base_dataset,
            batch_size=bs,
            shuffle=False,
            collate_fn=vqa_collator,
            num_workers=0,
            pin_memory=True,
            drop_last=True,  # Important if the dataset size is not an exact multiple of bs
        )
        if len(current_loader) < num_iterations_for_vram:
            print(f"Dataloader for batch size {bs} yields {len(current_loader)} batches, "
                  f"fewer than the requested {num_iterations_for_vram} iterations. Will run the available batches.")
        if len(current_loader) == 0:
            print(f"Dataloader for batch size {bs} is empty. Skipping.")
            results[bs] = "Dataloader empty"
            continue
        # Reset CUDA memory stats for each batch-size test
        torch.cuda.reset_peak_memory_stats(device)

        # Put the model in train mode for a realistic scenario (e.g. dropout layers active)
        model.train()
        optimizer = optim.AdamW(model.parameters(), lr=1e-5)  # Dummy optimizer
        try:
            for i, batch in enumerate(current_loader):
                if i >= num_iterations_for_vram:
                    break
                images = batch["image"].to(device)
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                attention_mask = batch["attention_mask"].to(device)

                optimizer.zero_grad(set_to_none=True)
                # Autocast to stay close to the train.py script
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16):
                    _, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)
                if loss is not None:
                    loss.backward()
                    optimizer.step()
                else:
                    print("Warning: Model did not return a loss. Backward pass and optimizer step skipped; "
                          "VRAM for these operations will not be measured.")

            # max_memory_allocated reports the peak tensor memory since the last
            # reset_peak_memory_stats call (cached-but-unused allocator blocks are excluded).
            peak_vram_allocated_bytes = torch.cuda.max_memory_allocated(device)
            peak_vram_allocated_mb = peak_vram_allocated_bytes / (1024 ** 2)
            print(f"Peak VRAM allocated for batch size {bs}: {peak_vram_allocated_mb:.2f} MB")
            results[bs] = f"{peak_vram_allocated_mb:.2f} MB"
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                peak_vram_allocated_bytes = torch.cuda.max_memory_allocated(device)  # Max allocated before the OOM
                peak_vram_allocated_mb = peak_vram_allocated_bytes / (1024 ** 2)
                print(f"CUDA out of memory for batch size {bs}.")
                print(f"Peak VRAM allocated before OOM: {peak_vram_allocated_mb:.2f} MB (may be approximate)")
                results[bs] = f"OOM (Peak before OOM: {peak_vram_allocated_mb:.2f} MB)"
            else:
                print(f"An unexpected runtime error occurred for batch size {bs}: {e}")
                results[bs] = f"Error: {e}"
                # raise e  # Optionally re-raise for debugging
        finally:
            del current_loader, optimizer
            if 'loss' in locals() and loss is not None: del loss
            if 'images' in locals(): del images
            if 'input_ids' in locals(): del input_ids
            if 'labels' in locals(): del labels
            if 'attention_mask' in locals(): del attention_mask
            torch.cuda.empty_cache()
print("\n--- Summary of VRAM Usage ---")
for bs, vram_usage in results.items():
print(f"Batch Size {bs}: {vram_usage}")