optimum/neuron/models/inference/backend/modules/kvcache/kv_cache_manager.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/kvcache/kv_cache_manager.py
import logging
from typing import List

import torch
from neuronx_distributed.parallel_layers import utils
from torch import Tensor, nn
from transformers import PretrainedConfig

from ...config import NxDNeuronConfig
from ..attention.gqa import (
    determine_sharding_strategy,
    get_shardable_head_counts,
)
from ..flashdecode.utils import get_cache_size
from .utils import dynamic_update_slice, fill_prefix


def _reshape_tiled_cache(cache: Tensor):
# We merge the tiles BHS(128 tiled)D -> BHSD
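    # For illustration (hypothetical shapes): a tiled cache of shape (2, 8, 128, 4, 64), i.e.
    # (batch, kv_heads_per_rank, tile_size, num_tiles, head_dim), becomes (2, 8, 512, 64).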
cache_shape = cache.shape
desired_shape = (
cache_shape[0],
cache_shape[1],
cache_shape[2] * cache_shape[3],
cache_shape[4],
)
cache = cache.reshape(desired_shape)
    return cache


def _slice_kv_cacheline(padding_side: str, seq_len: int, cache: Tensor):
if padding_side == "right":
return torch.ops.aten.slice(cache, dim=2, start=0, end=seq_len)
max_idx = cache.shape[2]
return torch.ops.aten.slice(cache, dim=2, start=max_idx - seq_len, end=max_idx)
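

# For illustration (hypothetical values): with a cache line of max length 1024 and seq_len=512,
# right padding keeps cache[:, :, :512, :], while left padding keeps cache[:, :, 512:1024, :],
# i.e. the most recent 512 positions.

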
class KVCacheManager(nn.Module):
"""
Key Value Cache Management.
It stores KV cache as a parameter list of the shape (batch_sz, num_kv_head_per_rank, max_len, head_dim),
and vends out read and write operations.
"""
def __init__(self, config: PretrainedConfig, neuron_config: NxDNeuronConfig, **kwargs):
super().__init__()
self.padding_side = neuron_config.padding_side
self.is_continuous_batching = neuron_config.continuous_batching
self.flash_decoding_enabled = neuron_config.flash_decoding_enabled
self.num_cores_per_group = neuron_config.num_cores_per_group
self.num_kv_head = kwargs["num_kv_head"]
# NOTE: Tiling the sequence dimension of the KV cache enables specific compiler optimizations like cascaded reductions
self.is_kv_cache_tiled = False # TODO: enable this when compiler fixes CR 158191111 (as per NxDI comment)
self._init_kv_shape(config, neuron_config)
num_layer = config.num_hidden_layers
dtype = neuron_config.torch_dtype
self.past_key_values = nn.ParameterList(
[nn.Parameter(torch.zeros(self.kv_shape, dtype=dtype), requires_grad=False) for _ in range(num_layer * 2)]
        )

    def _get_num_kv_heads_per_rank(self, config: PretrainedConfig, neuron_config: NxDNeuronConfig):
tp_degree = neuron_config.tp_degree
num_kv_head = self.num_kv_head
num_atten_head = config.num_attention_heads
gqa_sharding_strategy = determine_sharding_strategy(tp_degree, num_kv_head)
_, num_key_value_heads = get_shardable_head_counts(
tp_degree, num_atten_head, num_kv_head, gqa_sharding_strategy
)
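        # get_shardable_head_counts may pad or replicate the KV heads (depending on the GQA sharding
        # strategy) so that the returned count divides tp_degree evenly. Worked example (hypothetical
        # config): tp_degree=32 with num_kv_head=8 replicated to 32 shardable KV heads leaves exactly
        # one KV head per rank after the division below.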
        return utils.divide(num_key_value_heads, tp_degree)

    def _get_hidden_dim_per_head(self, config: PretrainedConfig):
hidden_size = config.hidden_size
num_atten_head = config.num_attention_heads
hidden_dim_per_head = hidden_size // num_atten_head
        return hidden_dim_per_head

    def _init_kv_shape(self, config: PretrainedConfig, neuron_config: NxDNeuronConfig):
max_batch_size = neuron_config.max_batch_size
max_len = neuron_config.sequence_length
num_kv_heads_per_rank = self._get_num_kv_heads_per_rank(config, neuron_config)
hidden_dim_per_head = self._get_hidden_dim_per_head(config)
if self.flash_decoding_enabled:
padded_max_len = max_len
if max_len % self.num_cores_per_group != 0:
padded_max_len += self.num_cores_per_group - max_len % self.num_cores_per_group
                logging.warning(
                    f"Max length needs to be a multiple of num_cores_per_group {self.num_cores_per_group}"
                    f" but got {max_len}. Padding it to {padded_max_len} to meet the requirement."
                )
max_len = get_cache_size(padded_max_len, self.num_cores_per_group)
if self.is_kv_cache_tiled:
num_tiles = int(max_len / 128)
# KV cache layout : BHS(128 tiled)D
self.kv_shape = (
max_batch_size,
num_kv_heads_per_rank,
128, # Sequence dim is tiled
num_tiles, # max_len = 128 * num_tiles
hidden_dim_per_head,
)
else:
# KV cache layout : BHSD
self.kv_shape = (
max_batch_size,
num_kv_heads_per_rank,
max_len,
hidden_dim_per_head,
)
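
        # Worked example (hypothetical config): max_batch_size=4, 8 KV heads per rank,
        # sequence_length=4096, head_dim=128 gives a BHSD cache line of shape (4, 8, 4096, 128);
        # with tiling enabled the same line would be stored as (4, 8, 128, 32, 128), since 4096 = 128 * 32.
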
def _fetch_cache(self, idx: int, kvcache_buffer=None):
if kvcache_buffer is not None:
return kvcache_buffer[idx][0], kvcache_buffer[idx][1]
k_cache, v_cache = self.past_key_values[idx * 2], self.past_key_values[idx * 2 + 1]
if self.is_kv_cache_tiled:
return _reshape_tiled_cache(k_cache), _reshape_tiled_cache(v_cache)
return k_cache, v_cache
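
    # Note: when a kvcache_buffer is supplied, K and V are fetched from that buffer rather than from
    # self.past_key_values (the tile reshape above is only applied to the internal parameters).
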
def get_kv_by_layer_id(self, key_layer_idx, gather_index=None, slice_index=None):
k_cache = self.past_key_values[key_layer_idx]
v_cache = self.past_key_values[key_layer_idx + 1]
        return k_cache, v_cache

    def get_cache(self, seq_len: int, skip_slice=False, **kwargs):
        """
        Return the K and V caches previously stored for all layers, up to seq_len.

        :param seq_len: sequence length (or bucket size from auto-bucketing e.g. 128, 512, 1024 etc.)
        :param skip_slice: if True, return the full cache lines without slicing them to seq_len
        :return: list of [K, V] pairs, one per layer
        """
slice_index, gather_index = None, None
past_key_values = []
for key_layer_idx in range(0, len(self.past_key_values), 2):
# get kv per layer
k_cache, v_cache = self.get_kv_by_layer_id(
key_layer_idx, gather_index=gather_index, slice_index=slice_index
)
if self.is_kv_cache_tiled:
k_cache = _reshape_tiled_cache(k_cache)
v_cache = _reshape_tiled_cache(v_cache)
# slice for partial view
if not skip_slice:
k_cache = _slice_kv_cacheline(self.padding_side, seq_len, k_cache)
v_cache = _slice_kv_cacheline(self.padding_side, seq_len, v_cache)
past_key_values.append([k_cache, v_cache])
return past_key_values
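
    # Usage sketch (hypothetical shapes): with padding_side="right", get_cache(512) returns one
    # [K, V] pair per layer, each of shape (max_batch_size, num_kv_heads_per_rank, 512, head_dim).
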
def update_cache(
self,
is_for_context_encoding: bool,
seq_ids: Tensor,
position_ids: Tensor,
new_key_values: List[Tensor],
seq_len: int,
scatter_index=None,
active_mask=None,
kvcache_buffer=None,
):
"""
Given the passed-in new_key_values, update the cache
:param scatter_index: tensor representing index to update
:param is_for_context_encoding: bool
:param seq_ids: tensor of size (batch_sz)
:param position_ids: tensor of size (batch_sz, bucket_sz)
:param new_key_values: list of tuple, the latest kv obtained at the end of the network from forward pass
:param seq_len: sequence length (or bucket size from auto-bucketing e.g. 128, 512, 1024 etc.)
:param scatter_index: tensor representing index to update
:param active_mask: tensor representing index to update
:param kvcache_buffer: if passed key states are updates to this buffer.
kvcache_buffer is 2D list where, 1st dim for layer and the second denotes K and V.
For example,
kvcache_buffer[1][0] is the K cache of the 1st layer
kvcache_buffer[4][1] is the V cache of the 4th layer
:return: list of tuple of (K, V)
"""
updated_kv_cache = []
for idx, kv_per_layer in enumerate(new_key_values):
latest_k, latest_v = kv_per_layer[0], kv_per_layer[1]
k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer)
if is_for_context_encoding:
if self.is_continuous_batching:
assert seq_ids.dim() == 1 and seq_ids.shape[0] == 1, "only supports single seq_id"
cache_idx = seq_ids
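                    # Build the start indices for dynamic_update_slice: zeros everywhere except the
                    # batch dimension, which receives the target seq_id, so this sequence's freshly
                    # computed K/V is written into its own cache line.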
indices = torch.zeros(k_cache.dim(), dtype=seq_ids.dtype, device=seq_ids.device)
indices = indices.scatter(
dim=0,
index=torch.tensor([0], dtype=torch.int64, device=k_cache.device),
src=cache_idx,
).to(torch.int32)
indices = indices.split(1)
indices = [t.squeeze() for t in indices]
k_cache = dynamic_update_slice(k_cache, latest_k, indices)
v_cache = dynamic_update_slice(v_cache, latest_v, indices)
else:
k_cache = fill_prefix(k_cache, latest_k)
v_cache = fill_prefix(v_cache, latest_v)
else:
if self.padding_side == "left":
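                    # Shift the cache left by one: drop the left-most (oldest) position and append
                    # the new token, keeping a constant cache length with left padding.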
k_cache = k_cache[:, :, 1:, :]
v_cache = v_cache[:, :, 1:, :]
k_cache = torch.cat([k_cache, latest_k], dim=2)
v_cache = torch.cat([v_cache, latest_v], dim=2)
else:
# copy the tensor of the new position into kv cache
if self.flash_decoding_enabled:
assert active_mask is not None, "active_mask should be specified for flash decoding!"
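                    # With flash decoding, each cache line only covers this core group's shard of the
                    # sequence, so global positions are mapped to local slots by integer division and
                    # inactive positions are redirected to a throwaway slot at the end of the line.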
garbage_pos = seq_len - 1 # treat last pos as garbage
updated_pos_ids = position_ids // self.num_cores_per_group
scatter_index = torch.where(active_mask == 1, updated_pos_ids, garbage_pos)
scatter_index_new = scatter_index.view(-1, 1, scatter_index.shape[-1], 1).expand_as(latest_k)
else:
scatter_index_new = self._get_index_to_update_new_position(
scatter_index, position_ids, latest_k
)
k_cache = torch.scatter(input=k_cache, dim=2, index=scatter_index_new, src=latest_k)
v_cache = torch.scatter(input=v_cache, dim=2, index=scatter_index_new, src=latest_v)
# Retiling
# TODO once compiler fixes CR 158191111 we can turn back output tiling on
# k_cache = k_cache.view(cache_shape)
# v_cache = v_cache.view(cache_shape)
updated_kv_cache.append(k_cache)
updated_kv_cache.append(v_cache)
# return updated kv cache to NxD runtime
        return updated_kv_cache

    def _get_index_to_update_new_position(self, scatter_index, position_ids, full_k):
scatter_index = position_ids.view(-1, 1, position_ids.shape[-1], 1).expand_as(full_k)
return scatter_index
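
    # Illustrative note (hypothetical shapes): during decode, position_ids has shape (batch_sz, 1);
    # viewing it as (batch_sz, 1, 1, 1) and expanding to latest_k's shape yields per-element indices
    # along the sequence dim, so torch.scatter in update_cache writes each new token at its position.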