def main()

in run_profile.py [0:0]


def _trace_path_for(args) -> str:
    """Return the profiler-trace filename encoding this run's configuration.

    The name is ``args.ckpt`` with every "/" replaced by "_", directly
    followed by "-"-joined ``key@value`` pairs and a ``.json`` suffix.
    """
    # The tuple order fixes the filename layout; keep it stable so trace
    # names stay comparable across runs.
    settings = (
        ("bf16", not args.no_bf16),
        ("sdpa", not args.no_sdpa),
        ("bs", args.batch_size),
        ("fuse", args.enable_fused_projections),
        ("upcast_vae", args.upcast_vae),
        ("steps", args.num_inference_steps),
        ("unet", args.compile_unet),
        ("vae", args.compile_vae),
        ("mode", args.compile_mode),
        ("change_comp_config", args.change_comp_config),
        ("do_quant", args.do_quant),
        ("device", args.device),
    )
    suffix = "-".join(f"{key}@{value}" for key, value in settings)
    # NOTE(review): there is no separator between the checkpoint part and
    # "bf16@..."; kept as-is so existing trace filenames do not change.
    return args.ckpt.replace("/", "_") + suffix + ".json"


def main(args) -> str:
    """Load the pipeline, warm it up, and profile one inference run.

    Args:
        args: Parsed command-line options (presumably an
            ``argparse.Namespace`` -- confirm against the caller). Reads
            ``ckpt`` (must be a ``str``), ``compile_unet``, ``compile_vae``,
            ``no_sdpa``, ``no_bf16``, ``upcast_vae``,
            ``enable_fused_projections``, ``do_quant``, ``compile_mode``,
            ``change_comp_config``, ``device``, ``batch_size`` and
            ``num_inference_steps``.

    Returns:
        The trace filename passed to ``profiler_runner``. The name is derived
        from the run configuration (see ``_trace_path_for``).
    """
    pipeline = load_pipeline(
        ckpt=args.ckpt,
        compile_unet=args.compile_unet,
        compile_vae=args.compile_vae,
        no_sdpa=args.no_sdpa,
        no_bf16=args.no_bf16,
        upcast_vae=args.upcast_vae,
        enable_fused_projections=args.enable_fused_projections,
        do_quant=args.do_quant,
        compile_mode=args.compile_mode,
        change_comp_config=args.change_comp_config,
        device=args.device,
    )

    # Warmup: two un-profiled runs before the profiled one. Presumably this
    # lets torch.compile (when enabled) finish compiling so that compilation
    # does not pollute the trace -- confirm.
    run_inference(pipeline, args)
    run_inference(pipeline, args)

    trace_path = _trace_path_for(args)
    # profiler_runner(trace_path, fn, *fn_args) is assumed to run fn under
    # the profiler and write the trace to trace_path -- TODO confirm.
    runner = functools.partial(profiler_runner, trace_path)
    # NOTE(review): if profiler_runner starts the profiler internally, this
    # range opens before profiling begins and may not appear in the trace.
    with torch.autograd.profiler.record_function("sdxl-brrr"):
        runner(run_inference, pipeline, args)
    return trace_path