optimum/executorch/attentions/custom_sdpa.py (103 lines of code) (raw):

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple, Union

import torch
from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa


def custom_sdpa_with_start_pos_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    """Compute attention with ExecuTorch's ``torch.ops.llama.custom_sdpa`` op.

    Two modes are supported:

    * Causal (default): ``attention_mask`` is only used to infer ``start_pos`` (the
      position of the first query token) from its first row; the op then applies
      causal masking itself and no explicit mask is passed.
    * Sliding (``kwargs["is_sliding"]`` truthy): ``attention_mask`` is passed through
      as an explicit mask, causal masking is disabled and ``start_pos`` is 0.

    Args:
        module: The attention layer. Its ``is_causal`` attribute must be true.
        query: Query tensor (assumed HF layout ``(batch, heads, seq, head_dim)``).
        key: Key tensor (same layout assumption as ``query``).
        value: Value tensor (same layout assumption as ``query``).
        attention_mask: Mask tensor. Its total size must be divisible by the key length
            (``key.shape[2]``) because it is reshaped to ``(-1, key.shape[2])``.
        scaling: Softmax scale forwarded to the op as ``scale``.
        softcap: Accepted for interface compatibility; ignored.
        head_mask: Accepted for interface compatibility; ignored.
        **kwargs: ``is_causal`` is discarded (``module.is_causal`` is used instead);
            ``is_sliding`` selects the explicit-mask mode.

    Returns:
        ``(output, None)``. ``output`` is cast back to the dtype of ``query``.
    """
    # Read the KV length before transposing (dim 2 under the assumed HF layout).
    max_seq_len = key.shape[2]

    # The custom op takes the FA2-style layout, so swap dims 1 and 2
    # (presumably heads <-> seq) before calling it.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # The op is run in fp32: upcast the inputs here and cast the output back to the
    # caller's dtype on return.
    input_dtype = query.dtype
    query = query.to(torch.float32)
    key = key.to(torch.float32)
    value = value.to(torch.float32)

    # Ignore any causal flag from kwargs; the module's own flag is authoritative.
    kwargs.pop("is_causal", None)
    assert module.is_causal, "Current variant supports only causal attention"
    is_causal = module.is_causal

    if kwargs.get("is_sliding", False):
        # The explicit mask already encodes which keys are visible, so disable the
        # op's causal masking; start_pos is irrelevant when a mask is supplied.
        is_causal = False
        attn_mask = attention_mask
        start_pos = 0
    else:
        attn_mask = None
        # Infer the position of the first query token from the mask: the first row
        # of a causal mask is visible up to and including that position, e.g.
        # [0, 0, 0, 0, -inf, -inf, -inf, -inf] -> start_pos = 3. Casting to long maps
        # float masks (0 / large negative) and bool masks (True / False) alike onto
        # values whose first minimum is the first masked slot.
        first_row_mask = attention_mask.reshape(-1, max_seq_len)[0, :].to(torch.long)
        first_masked_idx = torch.argmin(first_row_mask)
        # argmin returns 0 when nothing in the row is masked (every key is visible,
        # e.g. the final slot of a full cache), which would yield start_pos = -1.
        # In that case the first masked slot is "one past the end" instead.
        no_masked_slot = first_row_mask.min() == first_row_mask.max()
        first_masked_idx = torch.where(
            no_masked_slot,
            torch.full_like(first_masked_idx, max_seq_len),
            first_masked_idx,
        )
        start_pos = first_masked_idx.item() - 1

    # NOTE: `drpout_p` is presumably the keyword spelling registered in the op's
    # schema; verify against ExecuTorch before renaming it.
    output = torch.ops.llama.custom_sdpa(
        query,
        key,
        value,
        start_pos=start_pos,
        attn_mask=attn_mask,
        drpout_p=0.0,
        is_causal=is_causal,
        scale=scaling,
    )
    return output.to(input_dtype), None


def get_custom_sdpa_for_ring_kv_cache(
    exportable_module: torch.nn.Module,
) -> Callable:
    """Build an attention function that serves sliding-window layers from a ring KV cache.

    Layers with a truthy ``is_sliding`` attribute get an explicit mask built by their
    ``CustomRingKVCache``; every other layer is forwarded unchanged to
    ``custom_sdpa_with_start_pos_forward``.

    Args:
        exportable_module: Wrapper whose ``model.cache`` must be an
            ``ETCustomHybridCache`` when sliding-window layers are executed.

    Returns:
        A callable with the same signature as ``custom_sdpa_with_start_pos_forward``.

    Raises:
        ImportError: If the installed ExecuTorch does not provide ``CustomRingKVCache``.
    """
    # Lazy import to avoid a version-dependent class definition at module import time.
    from executorch import version

    try:
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomRingKVCache,
        )
    except ImportError as e:
        raise ImportError(f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch.") from e

    def _custom_sdpa_for_ring_kv_cache(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
        scaling: Optional[float] = None,
        softcap: Optional[float] = None,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        """Replace the mask for sliding layers with the ring-buffer mask, then run custom SDPA."""
        if getattr(module, "is_sliding", False):
            # Lazy import so this module stays out of the optimum import path
            # for ExecuTorch <= 0.6.0.
            from optimum.executorch.attentions.custom_kv_cache import ETCustomHybridCache

            layer_idx = module.layer_idx
            assert layer_idx is not None, "layer_idx is not set for sliding window attention."
            hybrid_cache = exportable_module.model.cache
            assert isinstance(hybrid_cache, ETCustomHybridCache), (
                f"Expected ETCustomHybridCache, got {type(hybrid_cache)}"
            )
            ring_cache = hybrid_cache.get_layer_cache(layer_idx)
            assert isinstance(ring_cache, CustomRingKVCache), f"Expected CustomRingKVCache, got {type(ring_cache)}"
            input_pos = hybrid_cache.cache_position[0].item()
            seqlen = query.shape[2]
            attention_mask = ring_cache.create_causal_mask_for_ring_buffer(input_pos, seqlen)
            kwargs["is_sliding"] = True

        return custom_sdpa_with_start_pos_forward(
            module,
            query,
            key,
            value,
            attention_mask,
            scaling,
            softcap,
            head_mask,
            **kwargs,
        )

    return _custom_sdpa_for_ring_kv_cache