def wurst_benchmark()

in benchmark/muse_perf.py [0:0]


def wurst_benchmark(batch_size, use_xformers):
    """Benchmark Wuerstchen (warp-ai/wuerstchen) text-to-image generation.

    Loads the pipeline in fp16 on CUDA, optionally enables xFormers
    memory-efficient attention, runs one untimed warmup call, then measures
    runtime (via ``Timer.blocked_autorange``) together with peak CUDA memory.

    Args:
        batch_size: Number of images generated per prompt in each call.
        use_xformers: If True, switch the pipeline to xFormers attention.

    Returns:
        Whatever ``measure_max_memory_allocated`` returns for the timed run
        (timing measurement plus peak-memory statistics).
    """
    model = "warp-ai/wuerstchen"
    dtype = torch.float16

    # Load in half precision and move the whole pipeline onto the GPU.
    pipe = AutoPipelineForText2Image.from_pretrained(model, torch_dtype=dtype)
    pipe = pipe.to("cuda")

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    def run_once():
        # Generate `batch_size` 1024x1024 images from the module-level prompt.
        pipe(
            prompt,
            height=1024,
            width=1024,
            prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
            prior_guidance_scale=4.0,
            num_images_per_prompt=batch_size,
        )

    # One untimed warmup call so one-off setup cost doesn't skew the timing.
    run_once()

    def timed_run():
        timer = Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": run_once},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, use_xformers: {use_xformers}",
            description=model,
        )
        return timer.blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(timed_run)