scripts/benchmark_models.py
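"""Benchmark muse generation speed.

Times MaskGitTransformer / MaskGiTUViT ``generate2`` with vanilla attention in
FP32 and FP16, and with xformers memory-efficient attention in FP16.

Example invocation (the config path below is a placeholder, not a file shipped
with this repo):

    python scripts/benchmark_models.py --config_path path/to/config.json --model_type uvit
"""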
import argparse
from functools import partial

import torch
import torch.utils.benchmark as benchmark

from muse import MaskGitTransformer, MaskGiTUViT


def benchmark_torch_function(f, *args, **kwargs):
    # Time f(*args, **kwargs) with torch.utils.benchmark. `.mean` is the average
    # runtime per call in seconds, so convert to milliseconds to match the prints below.
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    return round(t0.blocked_autorange(min_run_time=1).mean * 1e3, 2)

def create_model_and_benchmark(args):
    # Build the requested model from its config and move it to the benchmark device.
    if args.model_type == "transformer":
        config = MaskGitTransformer.load_config(args.config_path)
        model = MaskGitTransformer.from_config(config).to(args.device)
    elif args.model_type == "uvit":
        config = MaskGiTUViT.load_config(args.config_path)
        model = MaskGiTUViT.from_config(config).to(args.device)
    else:
        raise ValueError(f"Unsupported model_type: {args.model_type}")
    model.eval()
print("Running benchmark for vanilla attention in FP32 ...")
encoder_hidden_states = torch.randn(
args.batch_size, args.text_length, model.config.encoder_hidden_size, device=args.device, dtype=torch.float32
)
f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
time_vanilla = benchmark_torch_function(f)
print("Running benchmark for vanilla attention in FP16 ...")
encoder_hidden_states = encoder_hidden_states.half()
model = model.half()
f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
time_vanilla_fp16 = benchmark_torch_function(f)
print("Running benchmark for efficient attention in FP16 ...")
model.enable_xformers_memory_efficient_attention()
f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
time_efficient_fp16 = benchmark_torch_function(f)

    # Print the timings gathered above (milliseconds per call).
    print(f"Vanilla attention in FP32: {time_vanilla} ms")
    print(f"Vanilla attention in FP16: {time_vanilla_fp16} ms")
    print(f"Efficient attention in FP16: {time_efficient_fp16} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--model_type", type=str, default="transformer", choices=["transformer", "uvit"])
    parser.add_argument("--text_length", type=int, default=96)
    parser.add_argument("--time_steps", type=int, default=12)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    create_model_and_benchmark(args)