# utils/pipeline_utils_pixart.py
import torch
from torchao.quantization import (
    apply_dynamic_quant,
    change_linear_weights_to_int4_woqtensors,
    change_linear_weights_to_int8_woqtensors,
    swap_conv2d_1x1_to_linear,
)
from diffusers import DiffusionPipeline


def dynamic_quant_filter_fn(mod, *args):
    # Dynamically quantize only Linear layers with more than 16 input features
    # whose (in_features, out_features) shape is not in the exclusion list below.
    return (
        isinstance(mod, torch.nn.Linear)
        and mod.in_features > 16
        and (mod.in_features, mod.out_features)
        not in [
            (1280, 640),
            (1920, 1280),
            (1920, 640),
            (2048, 1280),
            (2048, 2560),
            (2560, 1280),
            (256, 128),
            (2816, 1280),
            (320, 640),
            (512, 1536),
            (512, 256),
            (512, 512),
            (640, 1280),
            (640, 1920),
            (640, 320),
            (640, 5120),
            (640, 640),
            (960, 320),
            (960, 640),
        ]
    )


def conv_filter_fn(mod, *args):
    # Match 1x1 Conv2d layers with 128 input or output channels so that
    # swap_conv2d_1x1_to_linear can replace them with equivalent Linear layers.
    return (
        isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1) and 128 in [mod.in_channels, mod.out_channels]
    )


def load_pipeline(
    ckpt: str,
    compile_transformer: bool,
    compile_vae: bool,
    no_sdpa: bool,
    no_bf16: bool,
    enable_fused_projections: bool,
    do_quant: str,
    compile_mode: str,
    change_comp_config: bool,
    device: str,
):
"""Loads the PixArt-Alpha pipeline."""
if do_quant and not compile_transformer:
raise ValueError("Compilation for Transformer must be enabled when quantizing.")
if do_quant and not compile_vae:
raise ValueError("Compilation for VAE must be enabled when quantizing.")
dtype = torch.float32 if no_bf16 else torch.bfloat16
print(f"Using dtype: {dtype}")
pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=dtype)

    if enable_fused_projections:
        print("Enabling fused QKV projections for both Transformer and VAE.")
        pipe.fuse_qkv_projections()

    if no_sdpa:
        print("Using vanilla attention.")
        pipe.transformer.set_default_attn_processor()
        pipe.vae.set_default_attn_processor()

    if device == "cuda":
        pipe = pipe.to("cuda")

    if compile_transformer:
        pipe.transformer.to(memory_format=torch.channels_last)
        print("Compile Transformer")
        # Swap 1x1 convolutions for equivalent Linear layers so they can be
        # quantized and handled as matmuls by the compiler.
        swap_conv2d_1x1_to_linear(pipe.transformer, conv_filter_fn)

        if compile_mode == "max-autotune" and change_comp_config:
            torch._inductor.config.conv_1x1_as_mm = True
            torch._inductor.config.coordinate_descent_tuning = True
            torch._inductor.config.epilogue_fusion = False
            torch._inductor.config.coordinate_descent_check_all_directions = True

        if do_quant:
            print("Apply quantization to Transformer")
            if do_quant == "int4weightonly":
                change_linear_weights_to_int4_woqtensors(pipe.transformer)
            elif do_quant == "int8weightonly":
                change_linear_weights_to_int8_woqtensors(pipe.transformer)
            elif do_quant == "int8dynamic":
                apply_dynamic_quant(pipe.transformer, dynamic_quant_filter_fn)
            else:
                raise ValueError(f"Unknown do_quant value: {do_quant}.")
            torch._inductor.config.force_fuse_int_mm_with_mul = True
            torch._inductor.config.use_mixed_mm = True

        pipe.transformer = torch.compile(pipe.transformer, mode=compile_mode, fullgraph=True)

    if compile_vae:
        pipe.vae.to(memory_format=torch.channels_last)
        print("Compile VAE")
        # Same treatment for the VAE: swap 1x1 convolutions for Linear layers first.
        swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)

        if compile_mode == "max-autotune" and change_comp_config:
            torch._inductor.config.conv_1x1_as_mm = True
            torch._inductor.config.coordinate_descent_tuning = True
            torch._inductor.config.epilogue_fusion = False
            torch._inductor.config.coordinate_descent_check_all_directions = True

        if do_quant:
            print("Apply quantization to VAE")
            if do_quant == "int4weightonly":
                change_linear_weights_to_int4_woqtensors(pipe.vae)
            elif do_quant == "int8weightonly":
                change_linear_weights_to_int8_woqtensors(pipe.vae)
            elif do_quant == "int8dynamic":
                apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
            else:
                raise ValueError(f"Unknown do_quant value: {do_quant}.")
            torch._inductor.config.force_fuse_int_mm_with_mul = True
            torch._inductor.config.use_mixed_mm = True

        # Only the decode path runs during generation, so compile just pipe.vae.decode.
        pipe.vae.decode = torch.compile(pipe.vae.decode, mode=compile_mode, fullgraph=True)

    pipe.set_progress_bar_config(disable=True)
    return pipe
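

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original benchmark harness: the checkpoint ID,
    # argument values, prompt, and output filename below are illustrative assumptions.
    pipeline = load_pipeline(
        ckpt="PixArt-alpha/PixArt-XL-2-1024-MS",
        compile_transformer=True,
        compile_vae=True,
        no_sdpa=False,
        no_bf16=False,
        enable_fused_projections=True,
        do_quant="int8dynamic",
        compile_mode="max-autotune",
        change_comp_config=True,
        device="cuda",
    )
    image = pipeline("A small cactus with a happy face in the Sahara desert.", num_inference_steps=30).images[0]
    image.save("pixart_sample.png")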