def main()

in benchmark/muse_perf.py [0:0]


def _record_result(
    csv_data,
    out,
    mem_bytes,
    *,
    batch_size,
    model_name,
    device,
    timesteps,
    resolution,
    use_xformers,
    use_fused_residual_norm=None,
):
    """Print one benchmark measurement and append its CSV row to ``csv_data``.

    Row columns, in order: batch_size, model_name, ``out.median * 1000``
    (presumably milliseconds if ``median`` is in seconds -- confirm against the
    benchmark helpers), device, timesteps, mem_bytes, resolution,
    use_xformers, use_fused_residual_norm (``None`` for models without that
    option).
    """
    Compare([out]).print()
    print("*******")

    csv_data.append(
        [
            batch_size,
            model_name,
            out.median * 1000,
            device,
            timesteps,
            mem_bytes,
            resolution,
            use_xformers,
            use_fused_residual_norm,
        ]
    )


def main():
    """Run every enabled benchmark and append the results to a CSV file.

    Which model families run is controlled by the module-level ``do_*`` flags.
    Rows are collected in memory and appended to
    ``benchmark/artifacts/all.csv`` once all benchmarks have finished.
    No header row is written.

    Command line:
        --device {4090,a100}  GPU label recorded in each row; it is also
                              forwarded as ``gpu_type`` to the SDXL and
                              SSD-1B benchmarks.
    """
    parser = ArgumentParser()
    parser.add_argument("--device", choices=["4090", "a100"], required=True)
    args = parser.parse_args()
    device = args.device

    csv_data = []

    def record(out, mem_bytes, **row):
        # Every row shares the same device column; bind it once here.
        _record_result(csv_data, out, mem_bytes, device=device, **row)

    for batch_size in [1, 8]:
        # Multi-step diffusion models: sweep timesteps and xformers.
        for timesteps in [12, 20]:
            for use_xformers in [False, True]:
                if do_sd15:
                    out, mem_bytes = sd_benchmark(batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers)
                    record(
                        out,
                        mem_bytes,
                        batch_size=batch_size,
                        model_name="stable_diffusion_1_5",
                        timesteps=timesteps,
                        resolution=512,
                        use_xformers=use_xformers,
                    )

                if do_sdxl:
                    out, mem_bytes = sdxl_benchmark(
                        batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers, gpu_type=device
                    )
                    record(
                        out,
                        mem_bytes,
                        batch_size=batch_size,
                        model_name="sdxl",
                        timesteps=timesteps,
                        resolution=1024,
                        use_xformers=use_xformers,
                    )

                if do_ssd_1b:
                    out, mem_bytes = ssd_1b_benchmark(
                        batch_size=batch_size, timesteps=timesteps, use_xformers=use_xformers, gpu_type=device
                    )
                    record(
                        out,
                        mem_bytes,
                        batch_size=batch_size,
                        model_name="ssd_1b",
                        timesteps=timesteps,
                        resolution=1024,
                        use_xformers=use_xformers,
                    )

                if do_muse:
                    for resolution in [256, 512]:
                        for use_fused_residual_norm in [False, True]:
                            out, mem_bytes = muse_benchmark(
                                resolution=resolution,
                                batch_size=batch_size,
                                timesteps=timesteps,
                                use_xformers=use_xformers,
                                use_fused_residual_norm=use_fused_residual_norm,
                            )
                            record(
                                out,
                                mem_bytes,
                                batch_size=batch_size,
                                model_name="muse",
                                timesteps=timesteps,
                                resolution=resolution,
                                use_xformers=use_xformers,
                                use_fused_residual_norm=use_fused_residual_norm,
                            )

        # The turbo variants are benchmarked at a single step only.
        turbo_timesteps = 1

        if do_sdxl_turbo:
            for use_xformers in [True, False]:
                out, mem_bytes = sdxl_turbo_benchmark(
                    batch_size=batch_size, timesteps=turbo_timesteps, use_xformers=use_xformers
                )
                record(
                    out,
                    mem_bytes,
                    batch_size=batch_size,
                    model_name="sdxl_turbo",
                    timesteps=turbo_timesteps,
                    resolution=1024,
                    use_xformers=use_xformers,
                )

        if do_sd_turbo:
            for use_xformers in [True, False]:
                out, mem_bytes = sd_turbo_benchmark(
                    batch_size=batch_size, timesteps=turbo_timesteps, use_xformers=use_xformers
                )
                record(
                    out,
                    mem_bytes,
                    batch_size=batch_size,
                    model_name="sd_turbo",
                    timesteps=turbo_timesteps,
                    resolution=512,
                    use_xformers=use_xformers,
                )

        if do_wurst:
            for use_xformers in [False, True]:
                out, mem_bytes = wurst_benchmark(batch_size, use_xformers)
                # wurst_benchmark takes no timesteps argument, so the column
                # records the literal string "default".
                record(
                    out,
                    mem_bytes,
                    batch_size=batch_size,
                    model_name="wurst",
                    timesteps="default",
                    resolution=1024,
                    use_xformers=use_xformers,
                )

        if do_lcm:
            for timesteps in [4, 8]:
                for use_xformers in [False, True]:
                    out, mem_bytes = lcm_benchmark(batch_size, timesteps, use_xformers)
                    record(
                        out,
                        mem_bytes,
                        batch_size=batch_size,
                        model_name="lcm",
                        timesteps=timesteps,
                        resolution=1024,
                        use_xformers=use_xformers,
                    )

    # Append mode: results accumulate across runs / devices in one file.
    with open("benchmark/artifacts/all.csv", "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(csv_data)