# scripts/benchmark_models.py

import torch

# Model classes; this import path assumes the open-muse package layout.
from muse import MaskGitTransformer, MaskGiTUViT
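
# `benchmark_torch_function` is called below but defined elsewhere in this
# script. A minimal sketch of such a helper, assuming CUDA-event timing with a
# single warmup call (the real helper may differ):
def benchmark_torch_function(f, *args, **kwargs):
    # Warmup run so one-time costs (kernel selection, allocator growth) are excluded.
    f(*args, **kwargs)
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        f(*args, **kwargs)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end)  # milliseconds
    else:
        import time

        t0 = time.perf_counter()
        f(*args, **kwargs)
        return (time.perf_counter() - t0) * 1000.0
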
def create_model_and_benchmark(args):
    # Build the requested model from its config and move it to the target device.
    if args.model_type == "transformer":
        config = MaskGitTransformer.load_config(args.config_path)
        model = MaskGitTransformer.from_config(config).to(args.device)
    elif args.model_type == "uvit":
        config = MaskGiTUViT.load_config(args.config_path)
        model = MaskGiTUViT.from_config(config).to(args.device)
    else:
        raise ValueError(f"Unknown model_type: {args.model_type!r}")

    model.eval()
print("Running benchmark for vanilla attention in FP32 ...")
encoder_hidden_states = torch.randn(
args.batch_size, args.text_length, model.config.encoder_hidden_size, device=args.device, dtype=torch.float32
)
f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
time_vanilla = benchmark_torch_function(f)
print("Running benchmark for vanilla attention in FP16 ...")
encoder_hidden_states = encoder_hidden_states.half()
model = model.half()
f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
time_vanilla_fp16 = benchmark_torch_function(f)
print("Running benchmark for efficient attention in FP16 ...")
model.enable_xformers_memory_efficient_attention()
f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
time_efficient_fp16 = benchmark_torch_function(f)
    # Report results (milliseconds, as returned by benchmark_torch_function).
    print(f"Vanilla attention in FP32: {time_vanilla} ms")
    print(f"Vanilla attention in FP16: {time_vanilla_fp16} ms")
    print(f"Efficient attention in FP16: {time_efficient_fp16} ms")