in run_profile.py [0:0]
def _build_trace_path(args) -> str:
    """Return the trace file name encoding the checkpoint and every benchmark knob.

    The name is the checkpoint id with ``/`` replaced by ``_``, followed directly
    (no separator, kept for compatibility with existing trace names) by
    ``tag@value`` pairs joined with ``-`` and a ``.json`` suffix.
    """
    # Order and tags must stay fixed: they define the on-disk trace file name.
    fields = (
        ("bf16", not args.no_bf16),
        ("sdpa", not args.no_sdpa),
        ("bs", args.batch_size),
        ("fuse", args.enable_fused_projections),
        ("upcast_vae", args.upcast_vae),
        ("steps", args.num_inference_steps),
        ("unet", args.compile_unet),
        ("vae", args.compile_vae),
        ("mode", args.compile_mode),
        ("change_comp_config", args.change_comp_config),
        ("do_quant", args.do_quant),
        ("device", args.device),
    )
    config_tag = "-".join(f"{tag}@{value}" for tag, value in fields)
    return args.ckpt.replace("/", "_") + config_tag + ".json"


def main(args) -> str:
    """Load the pipeline, warm it up, then run one profiled inference pass.

    Args:
        args: Parsed command-line namespace. Fields read here: ``ckpt``,
            ``compile_unet``, ``compile_vae``, ``no_sdpa``, ``no_bf16``,
            ``upcast_vae``, ``enable_fused_projections``, ``do_quant``,
            ``compile_mode``, ``change_comp_config``, ``device``,
            ``batch_size`` and ``num_inference_steps``. The whole namespace
            is also forwarded to ``run_inference``.

    Returns:
        The trace file name handed to ``profiler_runner``. Presumably the
        profiler trace is written there; confirm against ``profiler_runner``.
    """
    pipeline = load_pipeline(
        ckpt=args.ckpt,
        compile_unet=args.compile_unet,
        compile_vae=args.compile_vae,
        no_sdpa=args.no_sdpa,
        no_bf16=args.no_bf16,
        upcast_vae=args.upcast_vae,
        enable_fused_projections=args.enable_fused_projections,
        do_quant=args.do_quant,
        compile_mode=args.compile_mode,
        change_comp_config=args.change_comp_config,
        device=args.device,
    )

    # Warmup: two unprofiled runs, presumably so that one-time costs (e.g.
    # torch.compile compilation when compile_unet/compile_vae are set) do not
    # show up in the profiled run below.
    run_inference(pipeline, args)
    run_inference(pipeline, args)

    trace_path = _build_trace_path(args)
    runner = functools.partial(profiler_runner, trace_path)
    # NOTE(review): record_function only appears in a trace if a profiler is
    # already active when it is entered. profiler_runner is defined elsewhere;
    # if it starts the profiler itself, this "sdxl-brrr" label is not captured.
    # Confirm.
    with torch.autograd.profiler.record_function("sdxl-brrr"):
        runner(run_inference, pipeline, args)
    return trace_path