optimum/exporters/ipex/cache_utils.py
import os
from typing import List, Optional, Tuple
import intel_extension_for_pytorch as ipex
import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig
from optimum.intel.utils.import_utils import is_ipex_version
class IPEXPagedCache(Cache):
"""
A paged KV cache that grows dynamically as more tokens are generated: memory is allocated one block at a time, and each vendor backend can choose its own paged-cache memory layout.
ipex-xpu: (num_blocks, block_size, num_kv_heads, head_size) when flash decoding is supported, otherwise (num_blocks, num_kv_heads, head_size, block_size, 1) for keys and (num_blocks, num_kv_heads, head_size, block_size) for values.
ipex-cpu: (num_blocks, num_kv_heads, block_size, head_size) for both keys and values.
Example:
```python
>>> from transformers import AutoTokenizer
>>> from optimum.intel import IPEXModelForCausalLM
>>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache
>>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True)
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = IPEXPagedCache(model.config, max_batch_size=1, max_cache_len=2048)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_key_values = outputs.past_key_values # access cache filled with key/values from generation
```
"""
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: int,
device=None,
dtype=None,
**kwargs,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu")
device = device or default_device
self.device = device
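# Flash decoding requires IPEX > 2.4.99 (i.e. 2.5+) on CPU and > 2.5.99 (i.e. 2.6+) on XPU.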
self._supports_flash_decoding = (
is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
)
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
self.slots = torch.zeros([max_cache_len * max_batch_size], dtype=torch.int32, device=device)
torch._dynamo.mark_static_address(self._seen_tokens)
torch._dynamo.mark_static_address(self.slots)
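# Block size defaults to 16 for small caches (max_cache_len <= 64) and 64 otherwise; it can be overridden via the OI_PAGED_ATTN_BLOCK_SIZE environment variable.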
default_block_size = 16 if max_cache_len <= 64 else 64
self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
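# Total number of blocks: ceil(max_cache_len / block_size) per sequence, for every sequence in the batch.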
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
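# block_tables[i] maps sequence i's logical blocks to physical block indices (-1 = unassigned); free_blocks marks which physical blocks are still available (1 = free).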
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
max_batch_size, -1
)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
self.max_cache_len = max_cache_len
self.num_kv_heads = config.num_key_value_heads
self.num_hidden_layers = config.num_hidden_layers
if getattr(config, "head_dim", None) is not None:
head_size = config.head_dim
else:
head_size = config.hidden_size // config.num_attention_heads
self.head_size = head_size
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
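# The KV cache layout differs per backend and must match what the IPEX PagedAttention kernels expect.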
if device.type == "cpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
elif device.type == "xpu":
if self._supports_flash_decoding:
key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
else:
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
for i in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def reshape_and_cache(
self,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
# TODO: unify API definition between CPU and XPU in IPEX version > 2.6
if self.device.type == "xpu" and self._supports_flash_decoding:
# Workaround: `slots` is padded, but the XPU flash kernel requires len(slots) == len(key), so truncate it here (to be fixed in IPEX 2.8).
valid_len = key.shape[0]
truncated_slots = slots[:valid_len]
PagedAttention.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
truncated_slots,
)
else:
PagedAttention.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
)
# Slot allocation for the prefill phase; called outside the model forward pass.
def alloc_slot_for_prefill(self, input_lens: torch.Tensor, batch_size: int):
all_block_indices = []
all_slot_offsets = []
num_blocks = (input_lens + self.block_size - 1) // self.block_size
for i in range(batch_size):
nb = num_blocks[i]
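# Weight free blocks so that lower-indexed free blocks score highest, then take the top `nb` as this sequence's block table.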
scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device)
block_table = torch.topk(scores, nb).indices
self.block_tables[i][0:nb] = block_table
self.free_blocks[block_table] = 0
slots_range = torch.arange(input_lens[i], device=self.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
all_block_indices.append(self.block_tables[i][block_indices])
all_slot_offsets.append(slot_offsets)
all_block_indices = torch.cat(all_block_indices)
all_slot_offsets = torch.cat(all_slot_offsets).int()
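# Global slot index = physical block index * block_size + offset within the block.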
# Use inplace op to keep the same memory address, avoid recompile
self.slots[: all_block_indices.shape[0]].copy_(all_block_indices * self.block_size + all_slot_offsets)
# Slot allocation for the decode phase; called outside the model forward pass.
def alloc_slot_for_decode(self, batch_size: int):
start_block_idx = self._seen_tokens // self.block_size
slot_offset_in_block = (self._seen_tokens) % self.block_size
# Use inplace op to keep the same memory address, avoid recompile
self.slots.zero_()
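# Each decode step writes exactly one token per sequence, so only the first `batch_size` slots are used.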
for i in range(batch_size):
if slot_offset_in_block[i] == 0:
# need a new block:
b_idx = start_block_idx[i]
if self.block_tables[i][b_idx] == -1:
# Need a free block. Get indices of free blocks, select the first free block
scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device)
self.block_tables[i][b_idx] = scores.argmax()
self.free_blocks[self.block_tables[i][b_idx]] = 0
self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
Return:
A tuple containing the updated key and value states.
"""
self.reshape_and_cache(
key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
return self._seen_tokens.max()
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.max_cache_len
def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens.zero_()
self.block_tables.fill_(-1)
self.free_blocks.fill_(1)
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
origin_table = self.block_tables.clone()
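# Share full blocks by remapping each beam's block table to its selected parent beam; only the contents of the last, partially filled block are physically copied per layer below.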
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
mask = torch.where(self.block_tables == -1, 0, 1)
num_blocks = mask.sum(-1)
updated_table = torch.zeros_like(beam_idx)
for i in range(beam_idx.shape[0]):
nb = num_blocks[i]
self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
updated_table[i] = self.block_tables[i][nb - 1]
for layer_idx in range(self.num_hidden_layers):
# Only copy the last (partially filled) block of each sequence; indexing with the whole block table here would cause a core dump.
self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx].index_select(
0, updated_table[beam_idx]
)
self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx].index_select(
0, updated_table[beam_idx]
)
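# Blocks that no longer appear in any block table after reordering are returned to the free pool.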
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
for i in free_table:
if not (self.block_tables == i).any():
self.free_blocks[i] = 1
def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
max_seq_len = self._seen_tokens.max()
if maximum_length < 0:
maximum_length = max_seq_len - abs(maximum_length)
if max_seq_len <= maximum_length:
return
origin_table = self.block_tables.clone()
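# Shrink every sequence by the same number of tokens (max_seq_len - maximum_length) and truncate its block table accordingly; freed blocks are released below.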
for bs in range(self._seen_tokens.shape[0]):
new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len
num_blocks = (new_tokens + self.block_size - 1) // self.block_size
self.block_tables[bs, num_blocks:] = -1
self._seen_tokens[bs] = new_tokens
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
for i in free_table:
if not (self.block_tables == i).any():
self.free_blocks[i] = 1