utils/pipeline_utils.py (284 lines of code) (raw):

import functools
import inspect
import os
import pathlib
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from torch._inductor.package import load_package as inductor_load_package


@functools.lru_cache(maxsize=None)
def _fa3_accepted_kwargs():
    """Return the parameter names accepted by the installed FA3 ``flash_attn_func``.

    Different ``flash_attn_interface`` releases accept different keyword
    arguments, so ``flash_attn_func`` below forwards only the ones the installed
    version knows about. The installed signature does not change while the
    process runs, so it is inspected once and cached rather than on every
    attention call.
    """
    import flash_attn_interface

    return frozenset(inspect.signature(flash_attn_interface.flash_attn_func).parameters)


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    # probably wrong type for these 4
    qv: Optional[float] = None,
    q_descale: Optional[float] = None,
    k_descale: Optional[float] = None,
    v_descale: Optional[float] = None,
    window_size: Optional[List[int]] = None,
    sink_token_length: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    # probably wrong type for this too
    pack_gqa: Optional[float] = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Run FlashAttention v3 on float8 copies of ``q``, ``k`` and ``v``.

    Registered as a ``torch.library`` custom op so that ``torch.compile`` and
    ``torch.export`` treat the call as opaque; tracing uses the fake
    implementation registered right after this function.

    ``q``/``k``/``v`` are cast to ``torch.float8_e4m3fn`` before the call; no
    scaling is applied here (the ``*_descale`` arguments default to ``None``).
    Keyword arguments the installed ``flash_attn_interface.flash_attn_func``
    does not accept are silently dropped. ``window_size=None`` is forwarded as
    ``(-1, -1)``.

    Returns:
        Element ``[0]`` of FA3's return value, i.e. the attention output, laid
        out as ``(batch, seq_len, num_heads, head_dim)`` per the fake below.
        NOTE(review): this assumes the installed FA3 returns a tuple
        ``(out, softmax_lse)``; if it returns a bare tensor, ``[0]`` would
        select batch element 0 instead — confirm against the installed version.
    """
    if window_size is None:
        window_size = (-1, -1)
    else:
        window_size = tuple(window_size)

    import flash_attn_interface

    dtype = torch.float8_e4m3fn
    accepted = _fa3_accepted_kwargs()
    all_kwargs = {
        "softmax_scale": softmax_scale,
        "causal": causal,
        "qv": qv,
        "q_descale": q_descale,
        "k_descale": k_descale,
        "v_descale": v_descale,
        "window_size": window_size,
        "sink_token_length": sink_token_length,
        "softcap": softcap,
        "num_splits": num_splits,
        "pack_gqa": pack_gqa,
        "deterministic": deterministic,
        "sm_margin": sm_margin,
    }
    kwargs = {name: value for name, value in all_kwargs.items() if name in accepted}
    outputs = flash_attn_interface.flash_attn_func(
        q.to(dtype),
        k.to(dtype),
        v.to(dtype),
        **kwargs,
    )
    return outputs[0]


@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
    """Shape-only stand-in for ``flash_attn_func`` used while tracing / exporting."""
    # FA3 produces two outputs:
    #   1. output: (batch, seq_len, num_heads, head_dim)
    #   2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
    # The custom op only returns the first one, so the fake allocates a single
    # contiguous tensor like ``q``. (A second fake output would be
    # ``q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)``.)
    return torch.empty_like(q).contiguous()


# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
class FlashFusedFluxAttnProcessor3_0:
    """Fused-QKV Flux attention processor that computes attention with FlashAttention v3.

    Relies on fused projections on the attention module: ``attn.to_qkv`` for the
    sample stream and, when ``encoder_hidden_states`` is given,
    ``attn.to_added_qkv`` / ``attn.to_add_out`` for the context stream.
    """

    def __init__(self):
        # Fail at construction time rather than on the first forward pass.
        try:
            import flash_attn_interface  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "flash_attention v3 package is required to be installed"
            ) from err

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Compute attention for one Flux transformer block.

        Args:
            attn: The attention module providing the fused projections, the
                optional ``norm_*`` layers and ``heads``.
            hidden_states: Sample-stream tokens, ``(batch, seq, features)``
                (three dims, per the shape unpacking below).
            encoder_hidden_states: Context-stream tokens for joint blocks, or
                ``None`` for single blocks.
            attention_mask: Accepted for interface compatibility; not used.
            image_rotary_emb: Rotary embedding applied to the (concatenated)
                query and key when given.

        Returns:
            ``(hidden_states, encoder_hidden_states)``, each passed through its
            output projection, when ``encoder_hidden_states`` is given;
            otherwise the attention output without any output projection.
        """
        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # `sample` projections.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            (
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            ) = torch.split(encoder_qkv, split_size, dim=-1)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: context tokens go before sample tokens along the sequence axis
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # flash_attn_func works on (batch, seq_len, num_heads, head_dim), so the head
        # and sequence axes are swapped on the way in. Its single output tensor
        # already has that layout and is flattened straight to
        # (batch, seq_len, num_heads * head_dim). Do not index the result with [0]:
        # that would select the first batch element and drop the rest.
        hidden_states = flash_attn_func(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
        )
        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


# wrapper to automatically handle CUDAGraph record / replay over the given function
def cudagraph(f):
    """Wrap ``f`` so that calls are captured once per input-shape combination and replayed.

    On the first call for a given set of tensor input shapes, the tensor inputs
    are cloned into static buffers, ``f`` is run once as warmup, and a second run
    is captured into a ``torch.cuda.CUDAGraph`` and cached. Later calls with the
    same shapes copy their tensor inputs into those buffers, replay the graph,
    and return clones of the captured outputs as a list.

    Non-tensor arguments are only used at capture time; on replay they are
    ignored. NOTE(review): ``f`` is assumed to return an iterable of tensors.
    """
    from torch.utils._pytree import tree_map_only

    _graphs = {}

    def _shape_key(args, kwargs):
        # Key on the shapes of *all* tensor inputs, positional and keyword. Keying
        # on keyword tensors alone let a positional input of a new shape (e.g. the
        # latents passed positionally to the VAE decoder) reuse a graph captured
        # for another shape and be copied into a mis-shaped static buffer.
        positional = tuple(
            tuple(arg.shape) if isinstance(arg, torch.Tensor) else None for arg in args
        )
        keyword = tuple(
            (name, tuple(kwargs[name].shape))
            for name in sorted(kwargs)
            if isinstance(kwargs[name], torch.Tensor)
        )
        return positional, keyword

    def f_(*args, **kwargs):
        key = _shape_key(args, kwargs)
        cached = _graphs.get(key)
        if cached is not None:
            # use the cached wrapper if one exists. this will perform CUDAGraph replay
            wrapped, *_ = cached
            return wrapped(*args, **kwargs)

        # record a new CUDAGraph and cache it for future use
        g = torch.cuda.CUDAGraph()
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        f(*in_args, **in_kwargs)  # stream warmup
        with torch.cuda.graph(g):
            out_tensors = f(*in_args, **in_kwargs)

        def wrapped(*args, **kwargs):
            # CUDAGraphs require inputs / outputs to be in fixed memory locations, so
            # new inputs are copied into the static input buffers used at capture.
            for static, new in zip(in_args, args):
                if isinstance(static, torch.Tensor):
                    static.copy_(new)
            for name, value in kwargs.items():
                if isinstance(value, torch.Tensor):
                    in_kwargs[name].copy_(value)
            g.replay()
            # clone() outputs on the way out to disconnect them from the fixed output
            # memory locations. this allows for CUDAGraph reuse without accidentally
            # overwriting memory
            return [o.clone() for o in out_tensors]

        # cache function that does CUDAGraph replay (plus the buffers it closes over)
        _graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)

    return f_


def _warmup(pipeline, iterations=3):
    """Run ``pipeline`` on a dummy prompt ``iterations`` times to trigger compilation / capture."""
    for _ in range(iterations):
        pipeline(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]


def use_compile(pipeline):
    """Apply ``torch.compile`` to the transformer and VAE decoder, then warm up.

    Both are compiled with ``mode="max-autotune"`` and ``fullgraph=True``; the
    pipeline is run three times on a dummy prompt so compilation happens here
    rather than on the first real request.

    Returns:
        The same ``pipeline`` object, modified in place.
    """
    # Compile the compute-intensive portions of the model: denoising transformer / decoder
    pipeline.transformer = torch.compile(
        pipeline.transformer, mode="max-autotune", fullgraph=True
    )
    pipeline.vae.decode = torch.compile(
        pipeline.vae.decode, mode="max-autotune", fullgraph=True
    )

    # warmup for a few iterations (`num_inference_steps` shouldn't matter)
    _warmup(pipeline)

    return pipeline


def download_hosted_file(filename, output_path):
    """Download ``filename`` from the ``jbschlosser/flux-fast`` Hub repo into ``output_path``'s directory."""
    # Download hosted binaries from huggingface Hub.
    from huggingface_hub import hf_hub_download

    REPO_NAME = "jbschlosser/flux-fast"
    hf_hub_download(REPO_NAME, filename, local_dir=os.path.dirname(output_path))


def load_package(package_path):
    """Load an AOTInductor package, downloading it from the Hub first if missing locally.

    The download looks the file up by ``package_path``'s basename.
    """
    if not os.path.exists(package_path):
        download_hosted_file(os.path.basename(package_path), package_path)

    loaded_package = inductor_load_package(package_path, run_single_threaded=True)
    return loaded_package


def _export_and_package(module, args, kwargs, package_path):
    """``torch.export`` ``module`` with the given example inputs and AOTI-compile it to ``package_path``."""
    exported: torch.export.ExportedProgram = torch.export.export(
        module, args=args, kwargs=kwargs
    )
    torch._inductor.aoti_compile_and_package(
        exported,
        package_path=package_path,
        inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
    )


def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
    """Replace the transformer and VAE decoder with CUDA-graphed AOTInductor packages.

    Args:
        pipeline: The FLUX pipeline to modify in place.
        cache_dir: Directory holding the ``.pt2`` packages (created if needed).
        serialize: When True, export + AOTI-compile the modules and write the
            packages; otherwise load existing packages (downloading from the Hub
            if absent).
        is_timestep_distilled: Selects the text sequence length (256 vs 512),
            whether a ``guidance`` tensor is passed, and the package file names.

    The example inputs correspond to a 1024x1024 output resolution.

    Returns:
        The same ``pipeline`` object, after three warmup runs.
    """
    # create cache dir if needed
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def _example_tensor(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # === Transformer compile / export ===
    seq_length = 256 if is_timestep_distilled else 512
    # these shapes are for 1024x1024 resolution.
    transformer_kwargs = {
        "hidden_states": _example_tensor(1, 4096, 64),
        "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "pooled_projections": _example_tensor(1, 768),
        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
        "txt_ids": _example_tensor(seq_length, 3),
        "img_ids": _example_tensor(4096, 3),
        "joint_attention_kwargs": {},
        "return_dict": False,
    }

    # Possibly serialize model out
    transformer_package_path = os.path.join(
        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
    )
    if serialize:
        _export_and_package(pipeline.transformer, (), transformer_kwargs, transformer_package_path)

    # download serialized model if needed
    loaded_transformer = load_package(transformer_package_path)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_transformer(**transformer_kwargs)

    # Apply CUDAGraphs. CUDAGraphs are utilized in torch.compile with mode="max-autotune", but
    # they must be manually applied for torch.export + AOTI.
    loaded_transformer = cudagraph(loaded_transformer)
    pipeline.transformer.forward = loaded_transformer

    # warmup after cudagraphing
    with torch.no_grad():
        pipeline.transformer(**transformer_kwargs)

    # === Decoder compile / export ===
    # hack to get around export's limitations: export the module's forward, which now decodes
    pipeline.vae.forward = pipeline.vae.decode

    vae_decode_kwargs = {"return_dict": False}

    # Possibly serialize model out
    decoder_package_path = os.path.join(
        cache_dir, "exported_decoder.pt2" if is_timestep_distilled else "exported_dev_decoder.pt2"
    )
    if serialize:
        _export_and_package(
            pipeline.vae,
            (_example_tensor(1, 16, 128, 128),),
            vae_decode_kwargs,
            decoder_package_path,
        )

    # download serialized model if needed
    loaded_decoder = load_package(decoder_package_path)

    # warmup before cudagraphing
    loaded_decoder(_example_tensor(1, 16, 128, 128), **vae_decode_kwargs)

    loaded_decoder = cudagraph(loaded_decoder)
    pipeline.vae.decode = loaded_decoder

    # warmup for a few iterations
    _warmup(pipeline)

    return pipeline


def optimize(pipeline, args):
    """Apply the optimizations selected by ``args`` to ``pipeline``.

    In order: QKV fusion, the FlashAttention v3 processor, channels_last VAE,
    float8 dynamic quantization of the transformer, inductor tuning flags, and
    finally ``compile`` / ``export_aoti`` / no compilation according to
    ``args.compile_export_mode``.

    Raises:
        RuntimeError: If ``args.compile_export_mode`` is not one of
            ``compile``, ``export_aoti`` or ``disabled``.
    """
    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds

    # fuse QKV projections in Transformer and VAE
    if not args.disable_fused_projections:
        pipeline.transformer.fuse_qkv_projections()
        pipeline.vae.fuse_qkv_projections()

    # Use flash attention v3
    # NOTE(review): the FA3 processor calls attn.to_qkv, which presumably only exists
    # after fuse_qkv_projections(); confirm before combining it with disabled fusion.
    if not args.disable_fa3:
        pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    # switch memory layout to channels_last
    if not args.disable_channels_last:
        pipeline.vae.to(memory_format=torch.channels_last)

    # apply float8 quantization
    if not args.disable_quant:
        from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight  # , PerRow

        quantize_(
            pipeline.transformer,
            float8_dynamic_activation_float8_weight(),
            # float8_dynamic_activation_float8_weight(granularity=PerRow()),
        )

    # set inductor flags
    if not args.disable_inductor_tuning_flags:
        config = torch._inductor.config
        config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
        config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls

        # adjust autotuning algorithm
        config.coordinate_descent_tuning = True
        config.coordinate_descent_check_all_directions = True

        # TODO: Test out more mm settings
        # config.triton.enable_persistent_tma_matmul = True
        # config.max_autotune_gemm_backends = "ATEN,TRITON,CPP,CUTLASS"

    if args.compile_export_mode == "compile":
        pipeline = use_compile(pipeline)
    elif args.compile_export_mode == "export_aoti":
        pipeline = use_export_aoti(
            pipeline,
            cache_dir=args.cache_dir,
            serialize=(not args.use_cached_model),
            is_timestep_distilled=is_timestep_distilled,
        )
    elif args.compile_export_mode == "disabled":
        pass
    else:
        raise RuntimeError(
            "expected compile_export_mode arg to be one of {compile, export_aoti, disabled}"
        )

    return pipeline


def load_pipeline(args):
    """Load ``FluxPipeline`` from ``args.ckpt`` onto ``args.device`` and optimize it.

    Weights are loaded in float32 when ``args.disable_bf16`` is set, bfloat16
    otherwise; the progress bar is disabled.
    """
    load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
    pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
    pipeline.set_progress_bar_config(disable=True)
    pipeline = optimize(pipeline, args)
    return pipeline