def use_export_aoti()

in utils/pipeline_utils.py

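Exports the pipeline's transformer and VAE decoder with torch.export, AOT-compiles each into a .pt2 package with AOTInductor (serializing into cache_dir when serialize=True), reloads the packages, wraps them in manually captured CUDA graphs, and swaps them into the pipeline before a few end-to-end warmup runs. The load_package and cudagraph helpers are referenced but not defined in this section; hedged sketches follow the listing.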

import os
import pathlib

import torch


def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
    # create cache dir if needed
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def _example_tensor(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # === Transformer compile / export ===
    # text sequence length: 256 for the timestep-distilled model, 512 for dev
    seq_length = 256 if is_timestep_distilled else 512
    # these example shapes are for 1024x1024 resolution: a 128x128 latent grid
    # packed 2x2 into 4096 tokens of 64 channels each
    transformer_kwargs = {
        "hidden_states": _example_tensor(1, 4096, 64),
        "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "pooled_projections": _example_tensor(1, 768),
        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
        "txt_ids": _example_tensor(seq_length, 3),
        "img_ids": _example_tensor(4096, 3),
        "joint_attention_kwargs": {},
        "return_dict": False,
    }

    # Optionally export, AOT-compile, and serialize the transformer package
    transformer_package_path = os.path.join(
        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
    )
    if serialize:
        # Apply export
        exported_transformer: torch.export.ExportedProgram = torch.export.export(
            pipeline.transformer, args=(), kwargs=transformer_kwargs
        )

        # Apply AOTI
        torch._inductor.aoti_compile_and_package(
            exported_transformer,
            package_path=transformer_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )
    # load the serialized model package (from a previous run when serialize=False)
    loaded_transformer = load_package(transformer_package_path)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_transformer(**transformer_kwargs)

    # Apply CUDAGraphs. torch.compile applies CUDAGraphs automatically with
    # mode="max-autotune", but they must be applied manually for torch.export + AOTI.
    loaded_transformer = cudagraph(loaded_transformer)
    pipeline.transformer.forward = loaded_transformer

    # warmup after cudagraphing
    with torch.no_grad():
        pipeline.transformer(**transformer_kwargs)

    # hack to get around export's limitations: torch.export traces a module's
    # forward(), so alias decode onto forward to export the decode path
    pipeline.vae.forward = pipeline.vae.decode

    vae_decode_kwargs = {"return_dict": False}

    # Optionally export, AOT-compile, and serialize the decoder package
    decoder_package_path = os.path.join(
        cache_dir, "exported_decoder.pt2" if is_timestep_distilled else "exported_dev_decoder.pt2"
    )
    if serialize:
        # Apply export (example input: 16-channel 128x128 latents, i.e. 1024x1024 output)
        exported_decoder: torch.export.ExportedProgram = torch.export.export(
            pipeline.vae, args=(_example_tensor(1, 16, 128, 128),), kwargs=vae_decode_kwargs
        )

        # Apply AOTI
        torch._inductor.aoti_compile_and_package(
            exported_decoder,
            package_path=decoder_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )
    # load the serialized model package (from a previous run when serialize=False)
    loaded_decoder = load_package(decoder_package_path)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_decoder(_example_tensor(1, 16, 128, 128), **vae_decode_kwargs)

    loaded_decoder = cudagraph(loaded_decoder)
    pipeline.vae.decode = loaded_decoder

    # warmup for a few iterations
    for _ in range(3):
        pipeline(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]

    return pipeline
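
The load_package and cudagraph helpers used above are not shown in this section. The following is a minimal sketch of what they might look like, assuming load_package simply wraps torch._inductor.aoti_load_package (the companion to aoti_compile_and_package) and cudagraph is a capture-and-replay wrapper over torch.cuda.CUDAGraph; the details are assumptions, not the repo's exact implementation.

import torch


def load_package(package_path):
    # Assumed helper: load an AOTInductor-compiled .pt2 package as a callable.
    return torch._inductor.aoti_load_package(package_path)


def cudagraph(f):
    # Assumed helper: record f into a CUDA graph on the first call, then replay
    # the graph on later calls after copying fresh tensor inputs into the
    # static capture buffers. Requires f to be warmed up beforehand (as done
    # above) and always called with the same shapes, dtypes, and devices.
    state = {}

    def replayable(*args, **kwargs):
        if "graph" not in state:
            # First call: clone tensor inputs into static buffers and record.
            state["args"] = [a.clone() if isinstance(a, torch.Tensor) else a for a in args]
            state["kwargs"] = {
                k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()
            }
            state["graph"] = torch.cuda.CUDAGraph()
            with torch.cuda.graph(state["graph"]):
                state["output"] = f(*state["args"], **state["kwargs"])
        else:
            # Later calls: refresh the captured input buffers in place and replay.
            for static, new in zip(state["args"], args):
                if isinstance(static, torch.Tensor):
                    static.copy_(new)
            for k, new in kwargs.items():
                if isinstance(state["kwargs"][k], torch.Tensor):
                    state["kwargs"][k].copy_(new)
            state["graph"].replay()
        # Note: the returned tensors are the static output buffers and are
        # overwritten by the next call.
        return state["output"]

    return replayable

A hypothetical invocation, assuming pipeline is a diffusers FluxPipeline (the example shapes above match Flux's packed 1024x1024 latents); the model id and cache path are illustrative:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# The first run exports, compiles, and serializes the .pt2 packages into
# cache_dir; later runs can pass serialize=False to reuse them.
pipe = use_export_aoti(pipe, cache_dir="./aoti_cache", serialize=True, is_timestep_distilled=True)
image = pipe("a photo of a corgi", num_inference_steps=4).images[0]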