import csv
from argparse import ArgumentParser

import torch
from diffusers import (
    AutoencoderKL,
    AutoPipelineForText2Image,
    LatentConsistencyModelPipeline,
    LCMScheduler,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from torch.utils.benchmark import Compare, Timer
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)

from muse import MaskGiTUViT, PipelineMuse, VQGANModel

torch.manual_seed(0)
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")

num_threads = torch.get_num_threads()
prompt = "A high tech solarpunk utopia in the Amazon rainforest"

do_sd15       = True 
do_sdxl       = True 
do_ssd_1b     = True 
do_sdxl_turbo = True
do_sd_turbo   = True
do_muse       = True 
do_wurst      = True 
do_lcm        = True 

def main():
    args = ArgumentParser()
    args.add_argument("--device", choices=["4090", "a100"], required=True)

    args = args.parse_args()
    csv_data = []

    for batch_size in [1, 8]:
        for timesteps in [12, 20]:
            for use_xformers in [False, True]:
                if do_sd15:
                    out, mem_bytes = sd_benchmark(batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers)

                    Compare([out]).print()
                    print("*******")

                    csv_data.append(
                        [
                            batch_size,
                            "stable_diffusion_1_5",
                            out.median * 1000,
                            args.device,
                            timesteps,
                            mem_bytes,
                            512,
                            use_xformers,
                            None,
                        ]
                    )

                if do_sdxl:
                    out, mem_bytes = sdxl_benchmark(
                        batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers, gpu_type=args.device
                    )

                    Compare([out]).print()
                    print("*******")

                    csv_data.append(
                        [
                            batch_size,
                            "sdxl",
                            out.median * 1000,
                            args.device,
                            timesteps,
                            mem_bytes,
                            1024,
                            use_xformers,
                            None,
                        ]
                    )

                if do_ssd_1b:
                    out, mem_bytes = ssd_1b_benchmark(
                        batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers, gpu_type=args.device
                    )

                    Compare([out]).print()
                    print("*******")

                    csv_data.append(
                        [
                            batch_size,
                            "ssd_1b",
                            out.median * 1000,
                            args.device,
                            timesteps,
                            mem_bytes,
                            1024,
                            use_xformers,
                            None,
                        ]
                    )

                if do_muse:
                    for resolution in [256, 512]:
                        for use_fused_residual_norm in [False, True]:
                            out, mem_bytes = muse_benchmark(
                                resolution=resolution,
                                batch_size=batch_size,
                                timesteps=timesteps,
                                use_xformers=use_xformers,
                                use_fused_residual_norm=use_fused_residual_norm,
                            )

                            Compare([out]).print()
                            print("*******")

                            csv_data.append(
                                [
                                    batch_size,
                                    "muse",
                                    out.median * 1000,
                                    args.device,
                                    timesteps,
                                    mem_bytes,
                                    resolution,
                                    use_xformers,
                                    use_fused_residual_norm,
                                ]
                            )

        if do_sdxl_turbo:
            for use_xformers in [True, False]:
                timesteps = 1

                out, mem_bytes = sdxl_turbo_benchmark(
                    batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers
                )

                Compare([out]).print()
                print("*******")

                csv_data.append(
                    [
                        batch_size,
                        "sdxl_turbo",
                        out.median * 1000,
                        args.device,
                        timesteps,
                        mem_bytes,
                        1024,
                        use_xformers,
                        None,
                    ]
                )

        if do_sd_turbo:
            for use_xformers in [True, False]:
                timesteps = 1

                out, mem_bytes = sd_turbo_benchmark(
                    batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers
                )

                Compare([out]).print()
                print("*******")

                csv_data.append(
                    [
                        batch_size,
                        "sd_turbo",
                        out.median * 1000,
                        args.device,
                        timesteps,
                        mem_bytes,
                        512,
                        use_xformers,
                        None,
                    ]
                )

        if do_wurst:
            for use_xformers in [False, True]:
                out, mem_bytes = wurst_benchmark(batch_size, use_xformers)

                Compare([out]).print()
                print("*******")

                csv_data.append(
                    [
                        batch_size,
                        "wurst",
                        out.median * 1000,
                        args.device,
                        "default",
                        mem_bytes,
                        1024,
                        use_xformers,
                        None,
                    ]
                )

        if do_lcm:
            for timesteps in [4, 8]:
                for use_xformers in [False, True]:
                    out, mem_bytes = lcm_benchmark(batch_size, timesteps, use_xformers)

                    Compare([out]).print()
                    print("*******")

                    csv_data.append(
                        [
                            batch_size,
                            "lcm",
                            out.median * 1000,
                            args.device,
                            timesteps,
                            mem_bytes,
                            1024,
                            use_xformers,
                            None,
                        ]
                    )

    with open("benchmark/artifacts/all.csv", "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(csv_data)


def muse_benchmark(resolution, batch_size, timesteps, use_xformers, use_fused_residual_norm):
    model = "williamberman/muse_research_run_benchmarking_512_output"
    device = "cuda"
    dtype = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(model, subfolder="text_encoder")

    text_encoder = CLIPTextModelWithProjection.from_pretrained(model, subfolder="text_encoder")
    text_encoder.to(device=device, dtype=dtype)

    vae = VQGANModel.from_pretrained(model, subfolder="vae")
    vae.to(device=device, dtype=dtype)

    transformer = MaskGiTUViT(
        use_fused_mlp=False,
        use_fused_residual_norm=use_fused_residual_norm,
        force_down_up_sample=resolution == 512,
    )
    transformer = transformer.to(device=device, dtype=dtype)
    transformer.eval()

    if use_xformers:
        transformer.enable_xformers_memory_efficient_attention()

    pipe = PipelineMuse(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
    )
    pipe.device = device
    pipe.dtype = dtype

    seq_len = (resolution // 16) ** 2

    def benchmark_fn():
        pipe(prompt, num_images_per_prompt=batch_size, timesteps=timesteps, transformer_seq_len=seq_len)

    pipe(prompt, num_images_per_prompt=batch_size, timesteps=2, transformer_seq_len=seq_len)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=(
                f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, resolution: {resolution},"
                f" use_xformers: {use_xformers}, use_fused_residual_norm: {use_fused_residual_norm}"
            ),
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)


def wurst_benchmark(batch_size, use_xformers):
    model = "warp-ai/wuerstchen"
    device = "cuda"
    dtype = torch.float16

    pipe = AutoPipelineForText2Image.from_pretrained(model, torch_dtype=dtype).to(device)

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    def benchmark_fn():
        pipe(
            prompt,
            height=1024,
            width=1024,
            prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
            prior_guidance_scale=4.0,
            num_images_per_prompt=batch_size,
        )

    # warmup
    benchmark_fn()

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, use_xformers: {use_xformers}",
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)


def sdxl_benchmark(batch_size, timesteps, use_xformers, gpu_type):
    model = "stabilityai/stable-diffusion-xl-base-1.0"
    device = "cuda"
    dtype = torch.float16

    pipe = StableDiffusionXLPipeline.from_pretrained(model, torch_dtype=dtype)
    pipe = pipe.to(device)

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    if gpu_type == "4090" and batch_size == 8:
        output_type = "latent"
    else:
        output_type = "pil"

    def benchmark_fn():
        pipe(prompt, num_inference_steps=timesteps, num_images_per_prompt=batch_size, output_type=output_type)

    pipe(prompt, num_inference_steps=2, num_images_per_prompt=batch_size, output_type=output_type)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, use_xformers: {use_xformers}",
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)


def lcm_benchmark(batch_size, timesteps, use_xformers):
    model = "SimianLuo/LCM_Dreamshaper_v7"
    device = "cuda"
    dtype = torch.float16

    scheduler = LCMScheduler.from_pretrained(model, subfolder="scheduler")

    pipe = LatentConsistencyModelPipeline.from_pretrained(model, torch_dtype=dtype, scheduler=scheduler)
    pipe.to(device)

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    def benchmark_fn():
        pipe(prompt, num_inference_steps=timesteps, num_images_per_prompt=batch_size)

    pipe(prompt, num_inference_steps=2, num_images_per_prompt=batch_size)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, use_xformers: {use_xformers}",
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)


def ssd_1b_benchmark(batch_size, timesteps, use_xformers, gpu_type):
    model = "segmind/SSD-1B"
    device = "cuda"
    dtype = torch.float16

    pipe = StableDiffusionXLPipeline.from_pretrained(model, torch_dtype=dtype, use_safetensors=True, variant="fp16")
    pipe.to(device)

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    if gpu_type == "4090" and batch_size == 8:
        output_type = "latent"
    else:
        output_type = "pil"

    def benchmark_fn():
        pipe(prompt, num_inference_steps=timesteps, num_images_per_prompt=batch_size, output_type=output_type)

    pipe(prompt, num_inference_steps=2, num_images_per_prompt=batch_size, output_type=output_type)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, use_xformers: {use_xformers}",
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)


def sd_benchmark(batch_size, timesteps, use_xformers):
    model = "runwayml/stable-diffusion-v1-5"
    device = "cuda"
    dtype = torch.float16

    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder")
    text_encoder.to(device=device, dtype=dtype)

    vae = AutoencoderKL.from_pretrained(model, subfolder="vae")
    vae = vae.to(device=device, dtype=dtype)

    unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet")
    unet = unet.to(device=device, dtype=dtype)

    pipe = StableDiffusionPipeline.from_pretrained(
        model,
        vae=vae,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        safety_checker=None,
    )

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    def benchmark_fn():
        pipe(
            prompt,
            num_images_per_prompt=batch_size,
            num_inference_steps=timesteps,
        )

    pipe(prompt, num_images_per_prompt=batch_size, num_inference_steps=2)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, use_xformers: {use_xformers}",
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)

def sd_turbo_benchmark(batch_size, timesteps, use_xformers):
    model = "stabilityai/sd-turbo"
    dtype = torch.float16
    pipe = AutoPipelineForText2Image.from_pretrained(model, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
    pipe.to("cuda")

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    def benchmark_fn():
        pipe(
            prompt,
            num_images_per_prompt=batch_size,
            num_inference_steps=timesteps,
        )

    pipe(prompt, num_images_per_prompt=batch_size, num_inference_steps=2)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, use_xformers: {use_xformers}",
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)

def sdxl_turbo_benchmark(batch_size, timesteps, use_xformers):
    model = "stabilityai/sdxl-turbo"
    dtype = torch.float16
    pipe = AutoPipelineForText2Image.from_pretrained(model, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
    pipe.to("cuda")

    if use_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    def benchmark_fn():
        pipe(
            prompt,
            num_images_per_prompt=batch_size,
            num_inference_steps=timesteps,
        )

    pipe(prompt, num_images_per_prompt=batch_size, num_inference_steps=2)

    def fn():
        return Timer(
            stmt="benchmark_fn()",
            globals={"benchmark_fn": benchmark_fn},
            num_threads=num_threads,
            label=f"batch_size: {batch_size}, dtype: {dtype}, timesteps {timesteps}, use_xformers: {use_xformers}",
            description=model,
        ).blocked_autorange(min_run_time=1)

    return measure_max_memory_allocated(fn)


def measure_max_memory_allocated(fn):
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

    rv = fn()

    mem_bytes = torch.cuda.max_memory_allocated()

    return rv, mem_bytes


if __name__ == "__main__":
    main()
