# utils/benchmark_utils.py
import argparse
import functools
import os
from torch.profiler import record_function
def create_parser():
    """Build and return the CLI argument parser for the benchmark script.

    Flags fall into two groups: general run configuration (checkpoint,
    prompt, device, seed, ...) and optimization toggles. All optimizations
    are on by default; each ``--disable_*`` flag switches one of them off.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # (flag, add_argument kwargs) pairs, registered below in declaration order
    # so the --help output order is stable.
    cli_options = [
        # general options
        ("--ckpt",
         dict(type=str, default="black-forest-labs/FLUX.1-schnell",
              help="Model checkpoint path")),
        ("--prompt",
         dict(type=str, default="A cat playing with a ball of yarn",
              help="Text prompt")),
        ("--cache-dir",
         dict(type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
              help="Cache directory for storing exported models")),
        ("--use-cached-model",
         dict(action="store_true",
              help="Attempt to use cached model only (don't re-export)")),
        ("--device",
         dict(type=str, choices=["cuda", "cpu"], default="cuda",
              help="Device to use")),
        ("--num_inference_steps",
         dict(type=int, default=4,
              help="Number of denoising steps")),
        ("--output-file",
         dict(type=str, default="output.png",
              help="Output image file path")),
        ("--seed",
         dict(type=int, default=42,
              help="Random seed to use")),
        # file path for optional output PyTorch Profiler trace
        ("--trace-file",
         dict(type=str, default=None,
              help="Output PyTorch Profiler trace file path")),
        # optimizations - all are on by default but each can be disabled
        ("--disable_bf16",
         dict(action="store_true",
              help="Disables usage of torch.bfloat16")),
        # torch.compile OR torch.export + AOTI OR neither
        ("--compile_export_mode",
         dict(type=str, default="export_aoti",
              choices=["compile", "export_aoti", "disabled"],
              help="Configures how torch.compile or torch.export + AOTI are used")),
        # fused (q, k, v) projections
        ("--disable_fused_projections",
         dict(action="store_true",
              help="Disables fused q,k,v projections")),
        # channels_last memory format
        ("--disable_channels_last",
         dict(action="store_true",
              help="Disables usage of torch.channels_last memory format")),
        # Flash Attention v3
        ("--disable_fa3",
         dict(action="store_true",
              help="Disables use of Flash Attention V3")),
        # dynamic float8 quantization
        ("--disable_quant",
         dict(action="store_true",
              help="Disables usage of dynamic float8 quantization")),
        # flags for tuning inductor
        ("--disable_inductor_tuning_flags",
         dict(action="store_true",
              help="Disables use of inductor tuning flags")),
    ]
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)
    return parser
# helper to annotate a function within a profiler trace
def annotate(f, title):
    """Return a wrapper around ``f`` whose calls show up as ``title``
    in a PyTorch Profiler trace (via ``torch.profiler.record_function``).
    The wrapper forwards all positional/keyword arguments and the return
    value unchanged; ``functools.wraps`` preserves ``f``'s metadata.
    """
    def traced(*args, **kwargs):
        with record_function(title):
            return f(*args, **kwargs)
    return functools.wraps(f)(traced)