optimum/executorch/attentions/custom_kv_cache.py [58:96]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            layer_cache = CustomKVCache(
                max_batch_size=self.max_batch_size,
                max_context_length=self.max_cache_len,
                n_heads=self.num_key_value_heads,
                head_dim=self.head_dim,
                dtype=dtype,
            )
            self.kv_cache.append(layer_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
        using ExecutorTorch's CustomKVCache.

        Args:
            key_states (`torch.Tensor`):
                The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
            value_states (`torch.Tensor`):
                The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache update.

        Returns:
            A tuple containing the updated key and value states.
        """
        assert cache_kwargs is not None

        # Get cache position from cache_kwargs (used by StaticCache)
        cache_position = cache_kwargs.get("cache_position")
        assert cache_position is not None
        assert isinstance(cache_position, torch.Tensor)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



optimum/executorch/attentions/custom_kv_cache.py [226:264]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                layer_cache = CustomKVCache(
                    max_batch_size=self.max_batch_size,
                    max_context_length=self.max_cache_len,
                    n_heads=self.num_key_value_heads,
                    head_dim=self.head_dim,
                    dtype=dtype,
                )
            self.kv_cache.append(layer_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
        using ExecutorTorch's CustomKVCache or CustomRingKVCache depending on the layer type.

        Args:
            key_states (`torch.Tensor`):
                The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
            value_states (`torch.Tensor`):
                The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache update.

        Returns:
            A tuple containing the updated key and value states.
        """
        assert cache_kwargs is not None

        # Get cache position from cache_kwargs (used by HybridCache)
        cache_position = cache_kwargs.get("cache_position")
        assert cache_position is not None
        assert isinstance(cache_position, torch.Tensor)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



