optimum/executorch/attentions/custom_kv_cache.py (228 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from typing import Any, Dict, Optional, Tuple, Union import torch # If transformers is not installed, raise an ImportError try: from transformers.cache_utils import HybridCache, StaticCache except ImportError: raise ImportError("transformers is not installed. Please install it to use Static/HybridCache.") try: from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( CustomKVCache, CustomRingKVCache, ) except ImportError: raise ImportError("ExecutorTorch is not installed. Please install it to use Custom Cache.") class ETCustomStaticCache(StaticCache): """ Custom KV Cache implementation for ExecutorTorch that inherits from Hugging Face's StaticCache but uses custom operations for cache updates similar to ExecutorTorch's CustomStaticCache. """ def __init__( self, config, max_batch_size: int, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.float32, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ): super().__init__( config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype, layer_device_map=layer_device_map, ) # make sure layer_device_map is none assert layer_device_map is None assert device is None or device == "cpu", "Device must be None or 'cpu'" # Create a list of CustomKVCache instances, one per layer self.kv_cache = torch.nn.ModuleList() for _ in range(config.num_hidden_layers): layer_cache = CustomKVCache( max_batch_size=self.max_batch_size, max_context_length=self.max_cache_len, n_heads=self.num_key_value_heads, head_dim=self.head_dim, dtype=dtype, ) self.kv_cache.append(layer_cache) def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` using ExecutorTorch's CustomKVCache. Args: key_states (`torch.Tensor`): The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] value_states (`torch.Tensor`): The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache update. Returns: A tuple containing the updated key and value states. """ assert cache_kwargs is not None # Get cache position from cache_kwargs (used by StaticCache) cache_position = cache_kwargs.get("cache_position") assert cache_position is not None assert isinstance(cache_position, torch.Tensor) # Get the CustomKVCache instance for this layer layer_cache = self.kv_cache[layer_idx] # Use the CustomKVCache's update method # CustomKVCache expects input_pos, k_val, v_val and handles the transpose internally k_out, v_out = layer_cache.update( input_pos=cache_position, k_val=key_states, v_val=value_states, ) return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # Occupied cache == any slot in the 2nd dim (sequence length) holds a non-zero value # This is different from StaticCache which checks the 3rd dim if layer_idx is None: layer_idx = 0 return (self.kv_cache[layer_idx].k_cache[0, :, 0].any(dim=-1)).sum() @classmethod def from_legacy_cache( cls, config, legacy_cache, max_cache_len=None, device=None, dtype=None, ): """ Create an ETCustomStaticCache from a legacy cache implementation. Args: config: The model configuration legacy_cache: The legacy cache implementation max_cache_len: The maximum cache length device: The device for the new cache dtype: The data type for the new cache Returns: A new ETCustomStaticCache instance """ assert hasattr(legacy_cache, "k_cache") and hasattr(legacy_cache, "v_cache") # Extract dimensions from the legacy cache assert len(legacy_cache.k_cache.shape) == 4 if legacy_cache.k_cache.shape[1] == legacy_cache.n_heads: # Shape is [batch_size, n_heads, seq_len, head_dim] max_batch_size = legacy_cache.k_cache.shape[0] else: # Shape is [batch_size, seq_len, n_heads, head_dim] max_batch_size = legacy_cache.k_cache.shape[0] # Use the legacy cache's device and dtype if not specified if device is None and hasattr(legacy_cache, "device"): device = legacy_cache.device elif device is None and hasattr(legacy_cache.k_cache, "device"): device = legacy_cache.k_cache.device if dtype is None and hasattr(legacy_cache, "dtype"): dtype = legacy_cache.dtype elif dtype is None and hasattr(legacy_cache.k_cache, "dtype"): dtype = legacy_cache.k_cache.dtype assert device is None or device == "cpu" assert dtype is None or dtype == torch.float32 # Use the legacy cache's max_seq_len if max_cache_len is not specified if max_cache_len is None and hasattr(legacy_cache, "max_seq_len"): max_cache_len = legacy_cache.max_seq_len elif max_cache_len is None and hasattr(legacy_cache, "max_cache_len"): max_cache_len = legacy_cache.max_cache_len return cls( config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype, ) # Need to figure out if I have to inherit from HybridCache or StaticCache class ETCustomHybridCache(HybridCache): """ Custom Hybrid KV Cache implementation for ExecutorTorch that inherits from Hugging Face's HybridCache but uses ExecutorTorch's CustomKVCache for global layers and CustomRingKVCache for sliding window layers. """ def __init__( self, config, max_batch_size: int, max_cache_len: Optional[int] = None, device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.float32, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ): super().__init__( config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype, layer_device_map=layer_device_map, ) # make sure layer_device_map is none assert layer_device_map is None assert device is None or device == "cpu", "Device must be None or 'cpu'" self.cache_position = None # Create a list of cache instances, one per layer # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers self.kv_cache = torch.nn.ModuleList() for layer_idx in range(config.num_hidden_layers): # newer version of transfomer has is_sliding defined # for HybridCache if self.is_sliding[layer_idx]: # This is a sliding window layer layer_cache = CustomRingKVCache( max_batch_size=self.max_batch_size, max_context_length=self.sliding_window_len, n_heads=self.num_key_value_heads, head_dim=self.head_dim, dtype=dtype, ) else: layer_cache = CustomKVCache( max_batch_size=self.max_batch_size, max_context_length=self.max_cache_len, n_heads=self.num_key_value_heads, head_dim=self.head_dim, dtype=dtype, ) self.kv_cache.append(layer_cache) def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type. Args: key_states (`torch.Tensor`): The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] value_states (`torch.Tensor`): The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache update. Returns: A tuple containing the updated key and value states. """ assert cache_kwargs is not None # Get cache position from cache_kwargs (used by HybridCache) cache_position = cache_kwargs.get("cache_position") assert cache_position is not None assert isinstance(cache_position, torch.Tensor) self.cache_position = cache_position # Get the cache instance for this layer (either CustomKVCache or CustomRingKVCache) layer_cache = self.kv_cache[layer_idx] # Use the cache's update method # Both CustomKVCache and CustomRingKVCache have the same update interface k_out, v_out = layer_cache.update( input_pos=cache_position, k_val=key_states, v_val=value_states, ) return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" if layer_idx is None: layer_idx = 0 # For CustomRingKVCache, we need to handle the sequence length differently layer_cache = self.kv_cache[layer_idx] if self.is_sliding[layer_idx]: # CustomRingKVCache cache_position_manager which # maintains cache position for each slot in the kv cache # we return the max position + 1 to indicate max position # seen so far. Not sure if thats the correct interpretation # of sequence length return layer_cache.cache_positions_manager.cache_positions.max().item() + 1 return (layer_cache.k_cache[0, :, 0].any(dim=-1)).sum() def get_layer_cache(self, layer_idx: int): """ Get the cache for a specific layer. This method is dynamo-traceable. Args: layer_idx (int): The layer index Returns: The cache instance for the specified layer (CustomKVCache or CustomRingKVCache) """ return self.kv_cache[layer_idx] def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): """ Replace all KV caches in the module with ETCustomStaticCache. This modifies the model in place. Args: module: The module to modify config: The model configuration Returns: The modified module """ # Recursively replace KV caches return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype) def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): """ Helper function to recursively replace KV caches in the module. Args: module: The module to modify config: The model configuration Returns: The modified module """ # Check if module has static_cache (TorchExportableModuleWithStaticCache) if hasattr(module, "static_cache"): assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}" # TODO: Add replace_cache to exported module # in transformer's executorch.py if getattr(module, "replace_cache", None) is not None: static_cache = ETCustomStaticCache( config=config, max_batch_size=generation_config.cache_config.batch_size, max_cache_len=generation_config.cache_config.max_cache_len, device=generation_config.cache_config.device, dtype=cache_dtype, ) module.replace_cache(static_cache) else: module.static_cache = ETCustomStaticCache( config=config, max_batch_size=generation_config.cache_config.batch_size, max_cache_len=generation_config.cache_config.max_cache_len, device=generation_config.cache_config.device, dtype=cache_dtype, ) # Dont know why we need to this even though # CustomKVCache registers the attributes for i in range(len(module.static_cache.kv_cache)): setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache) setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache) # Check if module has cache (TorchExportableModuleWithHybridCache) elif hasattr(module, "cache"): assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}" # Replace with ETCustomHybridCache if getattr(module, "replace_cache", None) is not None: hybrid_cache = ETCustomHybridCache( config=config, max_batch_size=generation_config.cache_config.batch_size, max_cache_len=generation_config.cache_config.max_cache_len, device=generation_config.cache_config.device, dtype=cache_dtype, ) module.replace_cache(hybrid_cache) else: module.cache = ETCustomHybridCache( config=config, max_batch_size=generation_config.cache_config.batch_size, max_cache_len=generation_config.cache_config.max_cache_len, device=generation_config.cache_config.device, dtype=cache_dtype, ) # Register cache attributes for each layer for i in range(len(module.cache.kv_cache)): setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache) setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache) if module.cache.is_sliding[i]: # Register cache_positions as buffer for sliding window layers # This prevents it from being traced as a constant module.register_buffer( f"cache_positions_{i}", module.cache.kv_cache[i].cache_positions_manager.cache_positions, persistent=False, ) else: raise ValueError( "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) " "or 'cache' (TorchExportableModuleWithHybridCache) attribute" ) return module