in optimum/exporters/executorch/integrations.py
def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
    # Wrap the decoder with a static KV cache sized from the generation config.
    wrapped_decoder = (
        Seq2SeqLMDecoderExportableModuleWithStaticCache(
            model=self.full_model,
            max_static_cache_length=self.generation_config.cache_config.max_cache_len,
            batch_size=self.generation_config.cache_config.batch_size,
        )
        .to("cpu")
        .eval()
    )

    if isinstance(self.full_model, WhisperForConditionalGeneration):
        # Whisper encoder outputs have a fixed sequence length, so no dynamic shapes are needed.
        dynamic_shapes = None
    elif isinstance(self.full_model, T5ForConditionalGeneration):
        # Define a dynamic dimension for the encoder output sequence length.
        encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
        dynamic_shapes = {
            "decoder_input_ids": None,
            "encoder_hidden_states": {1: encoder_seq_len_dim},
            "cache_position": None,
        }
    else:
        raise ValueError(
            f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
        )

    # Export the decoder, pinning SDPA to the math backend so attention stays traceable.
    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
        exported_decoder = torch.export.export(
            wrapped_decoder,
            (decoder_input_ids, encoder_hidden_states, cache_position),
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )

    return exported_decoder
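
For reference, here is a self-contained sketch of the `dynamic_shapes` mechanics used in the T5 branch above. `TinyDecoderStep` and all shapes, names, and bounds in it are hypothetical stand-ins chosen for illustration; only the `torch.export.Dim` / `dynamic_shapes` pattern mirrors the real export.

import torch


class TinyDecoderStep(torch.nn.Module):
    """Toy stand-in for the wrapped decoder, with the same three-argument signature."""

    def __init__(self, hidden: int = 16, vocab: int = 100):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, hidden)
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
        step = self.embed(decoder_input_ids).squeeze(1)     # current decoder token
        ctx = self.proj(encoder_hidden_states).mean(dim=1)  # collapse the dynamic seq dim
        return step + ctx + cache_position.to(step.dtype)


# Mark dim 1 of encoder_hidden_states as dynamic, mirroring the T5 branch above.
seq_dim = torch.export.Dim("encoder_hidden_seq_length", max=64)
example_args = (
    torch.zeros((1, 1), dtype=torch.long),  # decoder_input_ids (static shape)
    torch.zeros((1, 20, 16)),               # encoder_hidden_states (dim 1 dynamic)
    torch.tensor([0], dtype=torch.long),    # cache_position (static shape)
)
ep = torch.export.export(
    TinyDecoderStep(),
    example_args,
    dynamic_shapes={
        "decoder_input_ids": None,
        "encoder_hidden_states": {1: seq_dim},
        "cache_position": None,
    },
    strict=True,
)
# The exported program accepts any encoder sequence length up to the declared max:
out = ep.module()(example_args[0], torch.zeros((1, 33, 16)), example_args[2])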