def muse_benchmark()

in benchmark/muse_perf.py

Benchmarks end-to-end latency and peak GPU memory for the MUSE text-to-image pipeline at a given resolution, batch size, and number of sampling timesteps, optionally enabling xformers memory-efficient attention and a fused residual norm.
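The listing assumes imports along these lines at the top of muse_perf.py (a sketch: Timer comes from torch.utils.benchmark, the tokenizer and text-encoder classes from transformers, and the model/pipeline classes are assumed to be exported by the local muse package). The function also reads prompt and num_threads, which are defined at module scope in the benchmark file.

import torch
from torch.utils.benchmark import Timer
from transformers import AutoTokenizer, CLIPTextModelWithProjection

from muse import MaskGiTUViT, PipelineMuse, VQGANModel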


def muse_benchmark(resolution, batch_size, timesteps, use_xformers, use_fused_residual_norm):
    model = "williamberman/muse_research_run_benchmarking_512_output"
    device = "cuda"
    dtype = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(model, subfolder="text_encoder")

    text_encoder = CLIPTextModelWithProjection.from_pretrained(model, subfolder="text_encoder")
    text_encoder.to(device=device, dtype=dtype)

    vae = VQGANModel.from_pretrained(model, subfolder="vae")
    vae.to(device=device, dtype=dtype)

    transformer = MaskGiTUViT(
        use_fused_mlp=False,
        use_fused_residual_norm=use_fused_residual_norm,
        # Only the 512-resolution configuration uses the extra down/up sampling blocks.
        force_down_up_sample=resolution == 512,
    )
    transformer = transformer.to(device=device, dtype=dtype)
    transformer.eval()

    if use_xformers:
        transformer.enable_xformers_memory_efficient_attention()

    pipe = PipelineMuse(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
    )
    pipe.device = device
    pipe.dtype = dtype

    # Transformer sequence length: the flattened VQ latent grid, resolution // 16 tokens per side.
    seq_len = (resolution // 16) ** 2

    def benchmark_fn():
        pipe(prompt, num_images_per_prompt=batch_size, timesteps=timesteps, transformer_seq_len=seq_len)

    # Warmup: a short 2-timestep run so CUDA kernels and caches are primed
    # before the timed measurements. `prompt` is defined at module scope.
    pipe(prompt, num_images_per_prompt=batch_size, timesteps=2, transformer_seq_len=seq_len)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,  # defined at module scope in muse_perf.py
            label=(
                f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, resolution: {resolution},"
                f" use_xformers: {use_xformers}, use_fused_residual_norm: {use_fused_residual_norm}"
            ),
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)
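
measure_max_memory_allocated is a helper defined elsewhere in muse_perf.py and not shown here. A minimal sketch of its likely shape, assuming the standard torch.cuda memory-stats API: reset the peak-allocation counter, run the timed callable, and return both the Timer measurement and the high-water mark in bytes.

def measure_max_memory_allocated(fn):
    # Sketch, not the file's actual implementation: clear cached blocks and
    # reset the peak counter so the reading reflects only this run.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    measurement = fn()  # runs Timer(...).blocked_autorange(min_run_time=1)

    torch.cuda.synchronize()
    max_memory = torch.cuda.max_memory_allocated()

    return measurement, max_memory

A hypothetical driver (not from the source file) would then call the benchmark and report both numbers:

measurement, max_memory = muse_benchmark(
    resolution=512,
    batch_size=8,
    timesteps=12,
    use_xformers=True,
    use_fused_residual_norm=True,
)
print(measurement)
print(f"peak memory: {max_memory / 1024**3:.2f} GiB")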