# utils/pipeline_utils.py
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from torchao.quantization import (
    apply_dynamic_quant,
    change_linear_weights_to_int4_woqtensors,
    change_linear_weights_to_int8_woqtensors,
    swap_conv2d_1x1_to_linear,
)

PROMPT = "ghibli style, a fantasy landscape with castles"
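

# Note: the shape exclusions below appear to be an empirically tuned list of
# (in_features, out_features) pairs for which int8 dynamic quantization does
# not pay off; Linear layers with in_features <= 16 are skipped as well.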
def dynamic_quant_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Linear)
        and mod.in_features > 16
        and (mod.in_features, mod.out_features)
        not in [
            (1280, 640),
            (1920, 1280),
            (1920, 640),
            (2048, 1280),
            (2048, 2560),
            (2560, 1280),
            (256, 128),
            (2816, 1280),
            (320, 640),
            (512, 1536),
            (512, 256),
            (512, 512),
            (640, 1280),
            (640, 1920),
            (640, 320),
            (640, 5120),
            (640, 640),
            (960, 320),
            (960, 640),
        ]
    )
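

# A 1x1 Conv2d is equivalent to a Linear layer applied per pixel, so
# swap_conv2d_1x1_to_linear can replace matching convs with Linear modules,
# which the quantization and compilation paths can then handle. The
# 128-channel condition below appears to be an empirical choice.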
def conv_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Conv2d)
        and mod.kernel_size == (1, 1)
        and 128 in [mod.in_channels, mod.out_channels]
    )


def load_pipeline(
    ckpt: str,
    compile_unet: bool,
    compile_vae: bool,
    no_sdpa: bool,
    no_bf16: bool,
    upcast_vae: bool,
    enable_fused_projections: bool,
    do_quant: str,  # "int4weightonly", "int8weightonly", "int8dynamic", or falsy to disable
    compile_mode: str,
    change_comp_config: bool,
    device: str,
):
    """Loads the SDXL (or SD v1-5) pipeline, optionally compiling and quantizing it."""
    if do_quant and not compile_unet:
        raise ValueError("Compilation for UNet must be enabled when quantizing.")
    if do_quant and not compile_vae:
        raise ValueError("Compilation for VAE must be enabled when quantizing.")

    dtype = torch.float32 if no_bf16 else torch.bfloat16
    print(f"Using dtype: {dtype}")

    if ckpt != "runwayml/stable-diffusion-v1-5":
        pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=dtype, use_safetensors=True)
    else:
        pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=dtype, use_safetensors=True, safety_checker=None)
        # The default SD v1-5 scheduler doesn't handle sigmas device placement
        # (https://github.com/huggingface/diffusers/pull/6173), so replace it.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    if not upcast_vae and ckpt != "runwayml/stable-diffusion-v1-5":
        print("Using a more numerically stable VAE.")
        pipe.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)

    if enable_fused_projections:
        print("Enabling fused QKV projections for both UNet and VAE.")
        pipe.fuse_qkv_projections()

    if upcast_vae and ckpt != "runwayml/stable-diffusion-v1-5":
        print("Upcasting VAE.")
        pipe.upcast_vae()

    if no_sdpa:
        print("Using vanilla attention.")
        pipe.unet.set_default_attn_processor()
        pipe.vae.set_default_attn_processor()

    if device == "cuda":
        pipe = pipe.to("cuda")

    if compile_unet:
        pipe.unet.to(memory_format=torch.channels_last)
        print("Compile UNet.")

        swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)

        if compile_mode == "max-autotune" and change_comp_config:
            torch._inductor.config.conv_1x1_as_mm = True
            torch._inductor.config.coordinate_descent_tuning = True
            torch._inductor.config.epilogue_fusion = False
            torch._inductor.config.coordinate_descent_check_all_directions = True

        if do_quant:
            print("Apply quantization to UNet.")
            if do_quant == "int4weightonly":
                change_linear_weights_to_int4_woqtensors(pipe.unet)
            elif do_quant == "int8weightonly":
                change_linear_weights_to_int8_woqtensors(pipe.unet)
            elif do_quant == "int8dynamic":
                apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
            else:
                raise ValueError(f"Unknown do_quant value: {do_quant}.")
            torch._inductor.config.force_fuse_int_mm_with_mul = True
            torch._inductor.config.use_mixed_mm = True

        pipe.unet = torch.compile(pipe.unet, mode=compile_mode, fullgraph=True)

    if compile_vae:
        pipe.vae.to(memory_format=torch.channels_last)
        print("Compile VAE.")

        swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)

        if compile_mode == "max-autotune" and change_comp_config:
            torch._inductor.config.conv_1x1_as_mm = True
            torch._inductor.config.coordinate_descent_tuning = True
            torch._inductor.config.epilogue_fusion = False
            torch._inductor.config.coordinate_descent_check_all_directions = True

        if do_quant:
            print("Apply quantization to VAE.")
            if do_quant == "int4weightonly":
                change_linear_weights_to_int4_woqtensors(pipe.vae)
            elif do_quant == "int8weightonly":
                change_linear_weights_to_int8_woqtensors(pipe.vae)
            elif do_quant == "int8dynamic":
                apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
            else:
                raise ValueError(f"Unknown do_quant value: {do_quant}.")
            torch._inductor.config.force_fuse_int_mm_with_mul = True
            torch._inductor.config.use_mixed_mm = True

        pipe.vae.decode = torch.compile(pipe.vae.decode, mode=compile_mode, fullgraph=True)

    pipe.set_progress_bar_config(disable=True)
    return pipe
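

# Example invocation (a minimal sketch; the checkpoint and flag values are
# illustrative assumptions, not prescribed by this module):
if __name__ == "__main__":
    pipe = load_pipeline(
        ckpt="stabilityai/stable-diffusion-xl-base-1.0",
        compile_unet=True,
        compile_vae=True,
        no_sdpa=False,
        no_bf16=False,
        upcast_vae=False,
        enable_fused_projections=True,
        do_quant="int8dynamic",
        compile_mode="max-autotune",
        change_comp_config=True,
        device="cuda",
    )
    image = pipe(PROMPT, num_inference_steps=30).images[0]
    image.save("output.png")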