optimum/exporters/executorch/integrations.py
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict
import torch
from packaging.version import parse
from torch.export import ExportedProgram
from torch.nn.attention import SDPBackend
from transformers import (
AutoProcessor,
PreTrainedModel,
StaticCache,
T5ForConditionalGeneration,
WhisperForConditionalGeneration,
)
from transformers.generation.configuration_utils import GenerationConfig
from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
from optimum.utils.import_utils import is_transformers_version
from .utils import save_config_to_constant_methods
class CausalLMExportableModule(torch.nn.Module):
"""
A wrapper module designed to make a Causal LM model exportable with `torch.export`.
This module ensures that the exported model is compatible with ExecuTorch.
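
    Example (a minimal sketch; the checkpoint name is illustrative and assumes a
    decoder-only checkpoint plus a recent `transformers`/`torch` stack):

        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        exportable = CausalLMExportableModule(model, use_custom_kv_cache=True, use_custom_sdpa=True)
        programs = exportable.export()  # {"model": ExportedProgram}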
"""
def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
super().__init__()
self.model = model
self.config = model.config
self.use_custom_kv_cache = use_custom_kv_cache
self.use_custom_sdpa = use_custom_sdpa
self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
logging.info(f"Metadata to be recorded in PTE: {self.metadata}")
def _prepare_export_inputs(self):
"""
Prepare example inputs and configurations for export.
Returns:
example_input_ids (torch.Tensor): Example input IDs tensor.
example_cache_position (torch.Tensor): Example cache position tensor.
dynamic_shapes (dict or None): Dynamic shape specifications for export.
strict (bool): Whether to use strict export mode.
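
        On a recent `transformers` (> 4.52.0), the prepared values look roughly like
        (illustrative):

            example_input_ids:      tensor([[0, 0, 0]])   # shape (1, 3)
            example_cache_position: tensor([0, 1, 2])
            dynamic_shapes:         {"input_ids": {1: seq_len_dim}, "cache_position": {0: seq_len_dim}}
            strict:                 True (False only on torch 2.7.0)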
"""
# Default values for legacy or fallback cases
example_input_ids = torch.tensor([[1]], dtype=torch.long)
example_cache_position = torch.tensor([0], dtype=torch.long)
dynamic_shapes = None
strict = True
is_using_hybrid_cache_wo_custom_sdpa_kv_cache = (
hasattr(self.config, "layer_types")
and getattr(self.config, "sliding_window", None) is not None
and not (self.use_custom_kv_cache and self.use_custom_sdpa)
)
if is_transformers_version(">", "4.52.0") and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache:
# Prepare inputs with dynamic shapes
seq_length = 3 # Sequence length > 1 to avoid specialization issues
example_input_ids = torch.zeros((1, seq_length), dtype=torch.long)
example_cache_position = torch.arange(seq_length, dtype=torch.long)
max_seq_len = self.metadata.get("get_max_seq_len")
sliding_window = self.metadata.get("sliding_window", float("inf"))
max_dim = min(max_seq_len, sliding_window) - 1
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
dynamic_shapes = {
"input_ids": {1: seq_len_dim},
"cache_position": {0: seq_len_dim},
}
strict = parse(torch.__version__) != parse("2.7.0") # Workaround for PyTorch bug #150994
return example_input_ids, example_cache_position, dynamic_shapes, strict
def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
if is_transformers_version(">=", "4.53.0.dev0"):
from transformers.integrations.executorch import sdpa_mask_without_vmap
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface
_custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
if self.use_custom_sdpa:
if self.use_custom_kv_cache:
AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
else:
                    # Manually set the attention implementation to custom_sdpa
                    exportable_module.model.model.config._attn_implementation = "custom_sdpa"
def export(
self,
) -> Dict[str, ExportedProgram]:
input_ids, cache_position, dynamic_shapes, strict = self._prepare_export_inputs()
logging.info(
f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}"
)
if is_transformers_version(">", "4.52.0"):
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
)
exportable_module = TorchExportableModuleForDecoderOnlyLM(
self.model,
max_batch_size=1,
max_cache_len=self.metadata.get("get_max_seq_len"),
)
self._register_attention_mask_for_4_53(exportable_module)
if self.use_custom_kv_cache:
from optimum.executorch.attentions.custom_kv_cache import (
replace_with_et_custom_kv_cache,
)
replace_with_et_custom_kv_cache(
exportable_module.model,
self.model.config,
self.model.generation_config,
self.model.dtype,
)
with torch.no_grad():
exported_program = exportable_module.export(input_ids, cache_position, dynamic_shapes, strict)
            # Apply the RemoveRedundantTransposes pass to remove any back-to-back
            # transpose ops that are not needed, e.g. the output of update_cache is
            # transposed and the input to custom_sdpa is transposed.
from executorch.extension.llm.export.export_passes import (
RemoveRedundantTransposes,
)
mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
exported_program = torch.export.export(
mutated_gm,
args=(input_ids, cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict,
)
else:
# Path to use legacy API, static export only due to pinned transformers version
from transformers.integrations.executorch import (
convert_and_export_with_cache,
)
exported_program = convert_and_export_with_cache(self.model, input_ids, cache_position)
return {"model": exported_program}
class VisionEncoderExportableModule(torch.nn.Module):
"""
A wrapper module designed to make a vision encoder-only model exportable with `torch.export`.
This module ensures that the exported model is compatible with ExecuTorch.
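
    Example (a minimal sketch; the checkpoint name is illustrative and assumes a
    config that exposes `num_channels` and `image_size`):

        from transformers import AutoModel

        model = AutoModel.from_pretrained("google/vit-base-patch16-224")
        exportable = VisionEncoderExportableModule(model)
        programs = exportable.export()  # {"model": ExportedProgram}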
"""
def __init__(self, model):
super().__init__()
self.model = model
self.config = model.config
# Metadata to be recorded in the pte model file
self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
def forward(self, pixel_values):
print(f"DEBUG: pixel_values: {pixel_values.shape}")
print(f"DEBUG: forward: {self.model.method_meta('forward')}")
return self.model(pixel_values=pixel_values)
def export(self, pixel_values=None) -> Dict[str, ExportedProgram]:
if pixel_values is None:
batch_size = 1
num_channels = self.config.num_channels
height = self.config.image_size
width = self.config.image_size
pixel_values = torch.rand(batch_size, num_channels, height, width)
with torch.no_grad():
return {
"model": torch.export.export(
self.model,
args=(),
kwargs={"pixel_values": pixel_values},
strict=False,
)
}
class MaskedLMExportableModule(torch.nn.Module):
"""
A wrapper module designed to make a Masked LM model exportable with `torch.export`.
This module ensures that the exported model is compatible with ExecuTorch.
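
    Example (a minimal sketch; the checkpoint name is illustrative):

        from transformers import AutoModelForMaskedLM

        model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
        exportable = MaskedLMExportableModule(model)
        programs = exportable.export()  # {"model": ExportedProgram}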
"""
def __init__(self, model):
super().__init__()
self.model = model
self.config = model.config
# Metadata to be recorded in the pte model file
self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
def forward(self, input_ids, attention_mask):
return self.model(input_ids, attention_mask)
def export(self, input_ids=None, attention_mask=None) -> Dict[str, ExportedProgram]:
max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 64)
max_seq_length = max(max_position_embeddings - 1, 1)
# Create dummy inputs with expected shapes
batch_size = 1
seq_length = max_seq_length
vocab_size = self.model.config.vocab_size
# Create example inputs (no need for tokenizer)
dummy_input_ids = (
torch.randint(0, vocab_size, (batch_size, seq_length), dtype=torch.long)
if input_ids is None
else input_ids
)
dummy_attention_mask = (
torch.ones((batch_size, seq_length), dtype=torch.long) if attention_mask is None else attention_mask
)
# Define dynamic shapes with Dim objects, always use Auto
dynamic_shapes = {
"input_ids": {1: torch.export.Dim.AUTO},
"attention_mask": {1: torch.export.Dim.AUTO},
}
# Export the model with dynamic dimensions
with torch.no_grad():
return {
"model": torch.export.export(
self.model,
args=(dummy_input_ids,),
kwargs={"attention_mask": dummy_attention_mask},
dynamic_shapes=dynamic_shapes,
strict=True,
)
}
class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
"""
A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
This module ensures that the exported encoder model is compatible with ExecuTorch.
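
    The forward pass returns only `last_hidden_state`, which is exactly what the
    exported decoder consumes as `encoder_hidden_states`. It is normally
    instantiated by `Seq2SeqLMExportableModule._export_encoder` rather than used
    directly.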
"""
def __init__(self, encoder_model):
super().__init__()
self.encoder = encoder_model
self.config = encoder_model.config
def forward(self, input_ids):
return self.encoder(input_ids).last_hidden_state
class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
"""
A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
specifically for use with static caching. This module ensures the exported decoder
is compatible with ExecuTorch.
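
    The wrapper bundles the decoder, the projection/LM head, and a `StaticCache`
    whose key/value buffers are registered as non-persistent module buffers so
    that `torch.export` can capture the in-place cache updates. It is normally
    instantiated by `Seq2SeqLMExportableModule._export_decoder` rather than used
    directly.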
"""
def __init__(self, model, max_static_cache_length, batch_size):
super().__init__()
# Get the decoder component
self.decoder = model.get_decoder()
if isinstance(model, WhisperForConditionalGeneration):
self.proj_out = model.proj_out
else:
self.proj_out = model.lm_head
self.config = model.config
# Initialize static cache
self.static_cache = StaticCache(
config=self.config,
max_batch_size=batch_size,
max_cache_len=max_static_cache_length,
device="cpu",
dtype=torch.float32,
)
# Register cache buffers to make them exportable
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder
outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
past_key_values=self.static_cache,
use_cache=True,
cache_position=cache_position,
)
# Apply linear projection (lm head) to obtain logits
logits = self.proj_out(outputs[0])
return logits
class Seq2SeqLMExportableModule(torch.nn.Module):
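    """
    A wrapper module that exports a full Seq2Seq LM (currently T5- and Whisper-style
    models) as two `torch.export` programs, one for the encoder and one for the
    decoder with a static KV cache, so that both are compatible with ExecuTorch.

    Example (a minimal sketch; the checkpoint name is illustrative):

        from transformers import AutoModelForSeq2SeqLM

        model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
        exportable = Seq2SeqLMExportableModule(model)
        programs = exportable.export()  # {"encoder": ..., "decoder": ...}

    After `export()`, `generate(prompt_token_ids, max_new_tokens)` runs a simple
    greedy decode against the exported programs.
    """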
def __init__(
self,
model: PreTrainedModel,
batch_size=1,
max_hidden_seq_length=4096,
cache_implementation="static",
max_cache_length=1024,
):
super().__init__()
self.full_model = model
self.encoder = model.get_encoder()
self.config = model.config
self.max_hidden_seq_length = max_hidden_seq_length
self.generation_config = GenerationConfig(
use_cache=True,
max_length=max_cache_length,
cache_implementation=cache_implementation,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_cache_length,
},
)
if isinstance(self.full_model, WhisperForConditionalGeneration):
self._processor = AutoProcessor.from_pretrained(model.config._name_or_path)
self._expected_encoder_input_shape = torch.Size(
(
1,
self._processor.feature_extractor.feature_size,
self._processor.feature_extractor.nb_max_frames,
)
)
        additional_configs = {"max_hidden_seq_length": max_hidden_seq_length}
# Metadata to be recorded in the pte model file
self.metadata = save_config_to_constant_methods(
self.config,
self.generation_config,
**additional_configs,
)
self.exported_encoder = None
self.exported_decoder = None
def _export_encoder(self, encoder_input_ids):
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
# Define dynamic sequence length for encoder
if isinstance(self.full_model, WhisperForConditionalGeneration):
assert (
encoder_input_ids.shape == self._expected_encoder_input_shape
), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}.
            For more information, please refer to the Whisper preprocessor config."""
dynamic_shapes = None
elif isinstance(self.full_model, T5ForConditionalGeneration):
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}}
else:
raise ValueError(
f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export."
)
# Export the encoder
with torch.no_grad():
exported_encoder = torch.export.export(
wrapped_encoder,
(encoder_input_ids,),
dynamic_shapes=dynamic_shapes,
strict=True,
)
return exported_encoder
def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
wrapped_decoder = (
Seq2SeqLMDecoderExportableModuleWithStaticCache(
model=self.full_model,
max_static_cache_length=self.generation_config.cache_config.max_cache_len,
batch_size=self.generation_config.cache_config.batch_size,
)
.to("cpu")
.eval()
)
if isinstance(self.full_model, WhisperForConditionalGeneration):
dynamic_shapes = None
elif isinstance(self.full_model, T5ForConditionalGeneration):
# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
dynamic_shapes = {
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
}
else:
raise ValueError(
f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
)
# Export the decoder
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
exported_decoder = torch.export.export(
wrapped_decoder,
(decoder_input_ids, encoder_hidden_states, cache_position),
dynamic_shapes=dynamic_shapes,
strict=True,
)
return exported_decoder
def export(
self,
encoder_input_ids=None,
decoder_input_ids=None,
encoder_hidden_states=None,
cache_position=None,
) -> Dict[str, ExportedProgram]:
if encoder_input_ids is None:
if isinstance(self.full_model, WhisperForConditionalGeneration):
example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
else:
example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
else:
example_encoder_input_ids = encoder_input_ids
self.exported_encoder = self._export_encoder(example_encoder_input_ids)
        if encoder_hidden_states is None:
example_encoder_hidden_states = self.exported_encoder.module()(example_encoder_input_ids)
else:
example_encoder_hidden_states = encoder_hidden_states
example_decoder_input_ids = (
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
self.exported_decoder = self._export_decoder(
example_decoder_input_ids,
example_encoder_hidden_states,
example_cache_position,
)
return {
"encoder": self.exported_encoder,
"decoder": self.exported_decoder,
}
def generate(self, prompt_token_ids, max_new_tokens):
with torch.no_grad():
# Run encoder
encoder_output = self.exported_encoder.module()(prompt_token_ids)
# Initialize with start token (0 for T5)
decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
generated_ids = [0]
# Generate tokens one by one
for i in range(max_new_tokens - 1):
# Run decoder for next token prediction
logits = self.exported_decoder.module()(
decoder_input_ids,
encoder_output,
torch.tensor([i], dtype=torch.long),
)
# Get next token
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
generated_ids.append(next_token)
# Update input for next iteration
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)
# Check if EOS token
if next_token == self.config.eos_token_id:
break
return generated_ids