# utils/pipeline_utils.py

def optimize(pipeline, args):
    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
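    # (Note, not stated in the original: timestep-distilled Flux checkpoints
    # such as FLUX.1-schnell are trained without guidance embeddings, so a
    # falsy `guidance_embeds` config entry is used here as the distillation signal.)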

    # fuse QKV projections in Transformer and VAE
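    # (Fusing folds the separate Q/K/V linear layers into one projection, so
    # attention runs a single larger matmul instead of three small ones; the
    # fused weight is also a single target for the float8 quantization below.)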
    if not args.disable_fused_projections:
        pipeline.transformer.fuse_qkv_projections()
        pipeline.vae.fuse_qkv_projections()

    # use Flash Attention v3
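    # (FlashFusedFluxAttnProcessor3_0 is assumed to be defined earlier in this
    # module; Flash Attention v3 kernels currently target Hopper-class GPUs
    # such as the H100.)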
    if not args.disable_fa3:
        pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    # switch memory layout to channels_last
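    # (channels_last stores activations NHWC, which the cuDNN convolution
    # kernels in the conv-heavy VAE generally prefer over the default NCHW layout.)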
    if not args.disable_channels_last:
        pipeline.vae.to(memory_format=torch.channels_last)

    # apply float8 quantization
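    # (torchao's dynamic float8 recipe quantizes weights ahead of time and
    # activations on the fly; the commented-out PerRow variant trades the
    # default per-tensor scales for finer per-row scales.)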
    if not args.disable_quant:
        from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight  # , PerRow

        quantize_(
            pipeline.transformer,
            float8_dynamic_activation_float8_weight(),
            # float8_dynamic_activation_float8_weight(granularity=PerRow()),
        )

    # set inductor flags
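    # (These flags trade longer compile/autotune time for faster generated
    # kernels; they only take effect for the compile/export paths below.)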
    if not args.disable_inductor_tuning_flags:
        config = torch._inductor.config
        config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matmuls
        config.epilogue_fusion = False  # do not fuse pointwise ops into matmul epilogues
        # adjust autotuning algorithm
        config.coordinate_descent_tuning = True
        config.coordinate_descent_check_all_directions = True
        # TODO: test out more mm settings
        # config.triton.enable_persistent_tma_matmul = True
        # config.max_autotune_gemm_backends = "ATEN,TRITON,CPP,CUTLASS"
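
    # hand off to torch.compile or AOTInductor export (use_compile and
    # use_export_aoti are assumed to be defined elsewhere in this module)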
    if args.compile_export_mode == "compile":
        pipeline = use_compile(pipeline)
    elif args.compile_export_mode == "export_aoti":
        pipeline = use_export_aoti(
            pipeline,
            cache_dir=args.cache_dir,
            serialize=(not args.use_cached_model),
            is_timestep_distilled=is_timestep_distilled,
        )
    elif args.compile_export_mode == "disabled":
        pass
    else:
        raise RuntimeError(
            "expected compile_export_mode arg to be one of "
            f"{{compile, export_aoti, disabled}}, got {args.compile_export_mode!r}"
        )

    return pipeline
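

# Example usage (a sketch; the FluxPipeline loading below and the argparse
# flags referenced in optimize() are assumptions about the surrounding repo,
# not part of this file):
#
#   import torch
#   from diffusers import FluxPipeline
#
#   pipeline = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   pipeline = optimize(pipeline, args)  # args: parsed CLI flags
#   image = pipeline("a photo of a forest", num_inference_steps=4).images[0]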