in optimum/exporters/executorch/integrations.py
def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
    if is_transformers_version(">=", "4.53.0.dev0"):
        from transformers.integrations.executorch import sdpa_mask_without_vmap
        from transformers.masking_utils import AttentionMaskInterface
        from transformers.modeling_utils import AttentionInterface

        _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
        if self.use_custom_sdpa:
            if self.use_custom_kv_cache:
                AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
                AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
                # Manually set the attention implementation to custom_sdpa_ring_kv_cache
                # This handles both regular sdpa and one for sliding window/local attention
                exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
            else:
                # Manually set the attention implementation to custom_sdpa
                # (custom SDPA without the ring KV cache handling for sliding window/local attention)
                exportable_module.model.model.config._attn_implementation = "custom_sdpa"