def _register_attention_mask_for_4_53()

in optimum/exporters/executorch/integrations.py


    def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
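        """Register custom SDPA attention and its mask function for transformers >= 4.53.

        When both custom SDPA and the custom (ring) KV cache are enabled, the ring-KV-cache
        SDPA variant is registered under "custom_sdpa_ring_kv_cache" together with
        sdpa_mask_without_vmap, and the wrapped model's config._attn_implementation is set
        accordingly; otherwise the config is pointed at the plain "custom_sdpa" implementation.
        """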
        if is_transformers_version(">=", "4.53.0.dev0"):
            from transformers.integrations.executorch import sdpa_mask_without_vmap
            from transformers.masking_utils import AttentionMaskInterface
            from transformers.modeling_utils import AttentionInterface

            _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
            if self.use_custom_sdpa:
                if self.use_custom_kv_cache:
                    AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
                    AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
                    # Manually set the attention implementation to custom_sdpa_ring_kv_cache
                    # This handles both regular sdpa and one for sliding window/local attention
                    exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
                else:
                    # Manually set the attention implementation to custom_sdpa
                    exportable_module.model.model.config._attn_implementation = "custom_sdpa"
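
For reference, below is a minimal sketch of the transformers >= 4.53 registration pattern this method relies on. The key "my_sdpa", the callable my_sdpa_attention, and its exact signature are illustrative assumptions (not part of optimum); only AttentionInterface.register, AttentionMaskInterface.register, sdpa_mask_without_vmap, and the config._attn_implementation assignment are taken from the method above.

    # Hedged sketch: registering a custom attention kernel and mask function, then
    # pointing a model config at it, mirroring _register_attention_mask_for_4_53.
    import torch

    from transformers.integrations.executorch import sdpa_mask_without_vmap
    from transformers.masking_utils import AttentionMaskInterface
    from transformers.modeling_utils import AttentionInterface


    def my_sdpa_attention(module, query, key, value, attention_mask, **kwargs):
        # Assumed callable convention for registered attention implementations:
        # returns (attn_output, attn_weights_or_None). Simplified (no GQA/scaling handling).
        out = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )
        return out.transpose(1, 2).contiguous(), None


    # Register the attention kernel and its mask function under the same key.
    AttentionInterface.register("my_sdpa", my_sdpa_attention)
    AttentionMaskInterface.register("my_sdpa", sdpa_mask_without_vmap)

    # A model would then be switched over the same way the method does it, e.g.:
    # some_model.config._attn_implementation = "my_sdpa"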