in utils/pipeline_utils.py
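# Imports this excerpt relies on (presumably already present near the top of the full
# module, alongside the load_package and cudagraph helpers referenced below).
import os
import pathlib

import torch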
def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
    # create cache dir if needed
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def _example_tensor(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # === Transformer compile / export ===
    seq_length = 256 if is_timestep_distilled else 512
    # these shapes are for 1024x1024 resolution.
    transformer_kwargs = {
        "hidden_states": _example_tensor(1, 4096, 64),
        "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        # guidance embedding is only used by the guidance-distilled (dev) checkpoint
        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "pooled_projections": _example_tensor(1, 768),
        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
        "txt_ids": _example_tensor(seq_length, 3),
        "img_ids": _example_tensor(4096, 3),
        "joint_attention_kwargs": {},
        "return_dict": False,
    }

    # Possibly serialize model out
    transformer_package_path = os.path.join(
        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
    )
    if serialize:
        # Apply export
        exported_transformer: torch.export.ExportedProgram = torch.export.export(
            pipeline.transformer, args=(), kwargs=transformer_kwargs
        )

        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_transformer,
            package_path=transformer_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )

    # load the packaged model (downloading the serialized artifact first if needed)
    loaded_transformer = load_package(transformer_package_path)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_transformer(**transformer_kwargs)

    # Apply CUDAGraphs. CUDAGraphs are utilized in torch.compile with mode="max-autotune", but
    # they must be manually applied for torch.export + AOTI.
    loaded_transformer = cudagraph(loaded_transformer)
    pipeline.transformer.forward = loaded_transformer

    # warmup after cudagraphing
    with torch.no_grad():
        pipeline.transformer(**transformer_kwargs)
    # hack to get around export's limitations: torch.export traces forward(), so alias the
    # VAE's decode method onto it
    pipeline.vae.forward = pipeline.vae.decode

    vae_decode_kwargs = {"return_dict": False}

    # Possibly serialize model out
    decoder_package_path = os.path.join(
        cache_dir, "exported_decoder.pt2" if is_timestep_distilled else "exported_dev_decoder.pt2"
    )
    if serialize:
        # Apply export
        exported_decoder: torch.export.ExportedProgram = torch.export.export(
            pipeline.vae, args=(_example_tensor(1, 16, 128, 128),), kwargs=vae_decode_kwargs
        )

        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_decoder,
            package_path=decoder_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )

    # load the packaged decoder (downloading the serialized artifact first if needed)
    loaded_decoder = load_package(decoder_package_path)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_decoder(_example_tensor(1, 16, 128, 128), **vae_decode_kwargs)

    loaded_decoder = cudagraph(loaded_decoder)
    pipeline.vae.decode = loaded_decoder

    # warmup for a few iterations
    for _ in range(3):
        pipeline(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]

    return pipeline
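

# The `cudagraph` helper referenced above is defined elsewhere in this file. As a rough,
# illustrative sketch (an assumed implementation, not the repo's exact code), "manually
# applying CUDA graphs" means: capture the callable once with static input buffers, then
# serve later calls by copying fresh inputs into those buffers and replaying the graph.
def _cudagraph_sketch(fn):
    from torch.utils._pytree import tree_flatten

    graph = torch.cuda.CUDAGraph()
    captured = {}  # filled on the first call: static inputs + static outputs

    def wrapper(*args, **kwargs):
        if not captured:
            # Warm up on a side stream, as recommended before graph capture.
            side_stream = torch.cuda.Stream()
            side_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side_stream):
                fn(*args, **kwargs)
            torch.cuda.current_stream().wait_stream(side_stream)

            # Capture: the first call's tensors become the graph's static input buffers.
            captured["inputs"] = (args, kwargs)
            with torch.cuda.graph(graph):
                captured["outputs"] = fn(*args, **kwargs)
            return captured["outputs"]

        # Replay path: copy new tensor inputs into the static buffers, then replay.
        static_leaves, _ = tree_flatten(captured["inputs"])
        new_leaves, _ = tree_flatten((args, kwargs))
        for static, new in zip(static_leaves, new_leaves):
            if isinstance(static, torch.Tensor):
                static.copy_(new)
        graph.replay()
        # Outputs are reused static buffers, so this only works for a fixed input shape
        # and callers must not hold onto results across calls.
        return captured["outputs"]

    return wrapper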
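

# `load_package` is likewise defined elsewhere in this file; per the comments above it may
# first download the serialized .pt2 artifact. Stripped of any download logic, a minimal
# sketch (assumed, hypothetical name) reduces to loading the AOTI package back into a callable:
def _load_package_sketch(package_path):
    # aoti_load_package loads the ahead-of-time compiled archive and returns a callable model
    return torch._inductor.aoti_load_package(package_path)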
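

# Illustrative usage (an assumed driver, not taken from this repo; paths are placeholders).
# The pipeline comes from diffusers, and use_export_aoti swaps in the exported/AOTI-compiled
# transformer and VAE decoder before images are generated.
if __name__ == "__main__":
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe = use_export_aoti(
        pipe,
        cache_dir="./aoti_cache",      # placeholder cache location
        serialize=True,                # compile and package on the first run
        is_timestep_distilled=True,    # schnell is the timestep-distilled checkpoint
    )
    image = pipe("a corgi astronaut", num_inference_steps=4).images[0]
    image.save("out.png")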