# Excerpt: main() from run_benchmark_pixart.py


def main(args, *, num_warmup: int = 3) -> tuple:
    """Benchmark the pipeline built by ``load_pipeline`` and return results.

    Loads the pipeline according to the flags on ``args``, then runs untimed
    warmup inferences. It times ``run_inference`` with ``benchmark_fn`` and
    packages that timing into a dict via ``generate_csv_dict``. Finally, it
    calls the pipeline once more to produce a sample image.

    Args:
        args: Parsed options namespace. This function reads ``ckpt``,
            ``compile_transformer``, ``compile_vae``, ``no_sdpa``, ``no_bf16``,
            ``enable_fused_projections``, ``do_quant``, ``compile_mode``,
            ``change_comp_config``, ``device``, ``prompt``,
            ``num_inference_steps`` and ``batch_size``. It also forwards the
            namespace to ``run_inference`` and ``generate_csv_dict``.
        num_warmup: Number of untimed ``run_inference`` calls made before
            benchmarking. Defaults to 3, the previously hard-coded count.

    Returns:
        A ``(data_dict, img)`` tuple. ``data_dict`` is whatever
        ``generate_csv_dict`` returns. ``img`` is the first entry of
        ``.images`` from a fresh pipeline call.
    """
    pipeline = load_pipeline(
        ckpt=args.ckpt,
        compile_transformer=args.compile_transformer,
        compile_vae=args.compile_vae,
        no_sdpa=args.no_sdpa,
        no_bf16=args.no_bf16,
        enable_fused_projections=args.enable_fused_projections,
        do_quant=args.do_quant,
        compile_mode=args.compile_mode,
        change_comp_config=args.change_comp_config,
        device=args.device,
    )

    # Warmup: untimed runs keep one-off costs out of the measurement.
    # Presumably this covers compilation/autotuning when the compile flags
    # are set — TODO confirm.
    for _ in range(num_warmup):
        run_inference(pipeline, args)

    # Not named `time`, to avoid shadowing the stdlib module name.
    elapsed_s = benchmark_fn(run_inference, pipeline, args)  # in seconds.

    data_dict = generate_csv_dict(
        pipeline_cls=type(pipeline).__name__,
        args=args,
        time=elapsed_s,
    )

    # Uses the same batch size as the benchmarked runs even though only one
    # image is kept. Presumably this keeps input shapes unchanged so a
    # compiled pipeline is not re-specialized — confirm before changing.
    img = pipeline(
        prompt=args.prompt,
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.batch_size,
    ).images[0]

    return data_dict, img