def main()

in eval/measure_vram.py [0:0]


def main():
    """Command-line entry point for the VRAM measurement script.

    Parses the command-line options, builds a ``config.VLMConfig`` from
    ``--lm_max_length`` and ``--lm_model_type`` together with a default
    ``config.TrainConfig``, prints both for reference, and then passes the
    parsed arguments and the two configs to ``measure_vram``.

    ``--batch_sizes`` is kept as the raw space-separated string here; it is
    forwarded unchanged inside the parsed arguments.
    """
    cli = argparse.ArgumentParser(
        description="Measure VRAM usage for a VisionLanguageModel at different batch sizes."
    )

    # (flag, add_argument keyword arguments). Registered in exactly this
    # order so the generated --help listing stays the same.
    option_table = (
        # Model and Config args
        (
            '--compile',
            dict(action='store_true', help='Compile the model with torch.compile.'),
        ),
        # Measurement control args
        (
            '--batch_sizes',
            dict(
                type=str,
                default="1 2 4 8 16 32 64 128 256 512",
                help='Space-separated list of batch sizes to test (e.g., "1 2 4 8").',
            ),
        ),
        (
            '--lm_max_length',
            dict(
                type=int,
                default=128,
                help='Maximum length of the input sequence for the language model.',
            ),
        ),
        (
            '--lm_model_type',
            dict(
                type=str,
                default='HuggingFaceTB/SmolLM2-135M-Instruct',
                help='Model type for the language model.',
            ),
        ),
        (
            '--num_iterations',
            dict(
                type=int,
                default=2,
                help='Number of forward/backward passes per batch size for VRAM measurement.',
            ),
        ),
    )
    for flag, add_kwargs in option_table:
        cli.add_argument(flag, **add_kwargs)

    parsed = cli.parse_args()

    # Both config objects are created before anything is printed.
    model_config = config.VLMConfig(
        lm_max_length=parsed.lm_max_length,
        lm_model_type=parsed.lm_model_type,
    )
    train_defaults = config.TrainConfig()  # Source of the default dataset path/name shown below

    print("--- VLM Config (from models.config) ---")
    print(model_config)  # Show base config
    print("--- Train Config Defaults (for dataset path/name if not specified via CLI) ---")
    print(f"Default dataset_path: {train_defaults.train_dataset_path}")
    print(f"Default dataset_name list: {train_defaults.train_dataset_name}")

    measure_vram(parsed, model_config, train_defaults)