in run_benchmark_pixart.py [0:0]
def main(args) -> tuple:
    """Load the PixArt pipeline, benchmark it, and generate one sample image.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``ckpt``, ``compile_transformer``, ``compile_vae``, ``no_sdpa``,
            ``no_bf16``, ``enable_fused_projections``, ``do_quant``,
            ``compile_mode``, ``change_comp_config``, ``device``, ``prompt``,
            ``num_inference_steps`` and ``batch_size``. ``run_inference`` and
            ``generate_csv_dict`` may read further attributes.

    Returns:
        A ``(data_dict, img)`` tuple: the benchmark record produced by
        ``generate_csv_dict`` and the first image from one extra, untimed
        pipeline call.
    """
    pipeline = load_pipeline(
        ckpt=args.ckpt,
        compile_transformer=args.compile_transformer,
        compile_vae=args.compile_vae,
        no_sdpa=args.no_sdpa,
        no_bf16=args.no_bf16,
        enable_fused_projections=args.enable_fused_projections,
        do_quant=args.do_quant,
        compile_mode=args.compile_mode,
        change_comp_config=args.change_comp_config,
        device=args.device,
    )

    # Warmup: untimed runs so one-time costs (presumably compilation when the
    # compile_* flags are set) are kept out of the measured run.
    num_warmup_runs = 3
    for _ in range(num_warmup_runs):
        run_inference(pipeline, args)

    # Named `elapsed` rather than `time` to avoid shadowing the stdlib module.
    elapsed = benchmark_fn(run_inference, pipeline, args)  # in seconds.
    data_dict = generate_csv_dict(
        pipeline_cls=type(pipeline).__name__,
        args=args,
        time=elapsed,
    )

    # One extra generation, outside the timed region, whose first image is
    # returned alongside the benchmark record.
    img = pipeline(
        prompt=args.prompt,
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.batch_size,
    ).images[0]
    return data_dict, img