optimum/exporters/ipex/cache_utils.py (174 lines of code) (raw):

import os from typing import List, Optional, Tuple import intel_extension_for_pytorch as ipex import torch from intel_extension_for_pytorch.llm.modules import PagedAttention from transformers import Cache, PretrainedConfig from optimum.intel.utils.import_utils import is_ipex_version class IPEXPagedCache(Cache): """ A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout. ipex-xpu: ipex-cpu: Example: ```python >>> from transformers import AutoTokenizer >>> from optimum.intel import IPEXModelForCausalLM >>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache >>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True) >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> past_key_values = IPEXPagedCache() >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation ``` """ def __init__( self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device=None, dtype=None, **kwargs, ) -> None: super().__init__() self.max_batch_size = max_batch_size default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu") device = device or default_device self.device = device self._supports_flash_decoding = ( is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99") ) # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device) self.slots = torch.zeros([max_cache_len * max_batch_size], dtype=torch.int32, device=device) torch._dynamo.mark_static_address(self._seen_tokens) torch._dynamo.mark_static_address(self.slots) default_block_size = 16 if max_cache_len <= 64 else 64 self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size))) self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( max_batch_size, -1 ) self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device) self.max_cache_len = max_cache_len self.num_kv_heads = config.num_key_value_heads self.num_hidden_layers = config.num_hidden_layers if getattr(config, "head_dim", None) is not None: head_size = config.head_dim else: head_size = config.hidden_size // config.num_attention_heads self.head_size = head_size self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] if device.type == "cpu": key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) elif device.type == "xpu": if self._supports_flash_decoding: key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) else: key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) for i in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device) new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device) torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) def reshape_and_cache( self, key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor, ): # TODO: unify API definition between CPU and XPU in IPEX version > 2.6 if self.device.type == "xpu" and self._supports_flash_decoding: # make a WA here as slots here is padded but XPU does not support slots with length not equal to key length, will fix it in IPEX 2.8 valid_len = key.shape[0] truncated_slots = slots[:valid_len] PagedAttention.reshape_and_cache_flash( key, value, key_cache, value_cache, truncated_slots, ) else: PagedAttention.reshape_and_cache( key, value, key_cache, value_cache, slots, ) # outside the model forward def alloc_slot_for_prefill(self, input_lens: torch.Tensor, batch_size: int): all_block_indices = [] all_slot_offsets = [] num_blocks = (input_lens + self.block_size - 1) // self.block_size for i in range(batch_size): nb = num_blocks[i] scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device) block_table = torch.topk(scores, nb).indices self.block_tables[i][0:nb] = block_table self.free_blocks[block_table] = 0 slots_range = torch.arange(input_lens[i], device=self.device) block_indices = slots_range // self.block_size slot_offsets = slots_range % self.block_size all_block_indices.append(self.block_tables[i][block_indices]) all_slot_offsets.append(slot_offsets) all_block_indices = torch.cat(all_block_indices) all_slot_offsets = torch.cat(all_slot_offsets).int() # Use inplace op to keep the same memory address, avoid recompile self.slots[: all_block_indices.shape[0]].copy_(all_block_indices * self.block_size + all_slot_offsets) # outside the model forward def alloc_slot_for_decode(self, batch_size: int): start_block_idx = self._seen_tokens // self.block_size slot_offset_in_block = (self._seen_tokens) % self.block_size # Use inplace op to keep the same memory address, avoid recompile self.slots.zero_() for i in range(batch_size): if slot_offset_in_block[i] == 0: # need a new block: b_idx = start_block_idx[i] if self.block_tables[i][b_idx] == -1: # Need a free block. Get indices of free blocks, select the first free block scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1, device=self.device) self.block_tables[i][b_idx] = scores.argmax() self.free_blocks[self.block_tables[i][b_idx]] = 0 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. Return: A tuple containing the updated key and value states. """ self.reshape_and_cache( key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots ) return self.key_cache[layer_idx], self.value_cache[layer_idx] def get_seq_length(self) -> int: """Returns the sequence length of the cached states that were seen by the model.""" return self._seen_tokens.max() def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" return self.max_cache_len def reset(self): """Resets the cache values while preserving the objects""" self._seen_tokens.zero_() self.block_tables.fill_(-1) self.free_blocks.fill_(1) def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" origin_table = self.block_tables.clone() updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device)) mask = torch.where(self.block_tables == -1, 0, 1) num_blocks = mask.sum(-1) updated_table = torch.zeros_like(beam_idx) for i in range(beam_idx.shape[0]): nb = num_blocks[i] self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1] updated_table[i] = self.block_tables[i][nb - 1] for layer_idx in range(self.num_hidden_layers): # The updated_table cannot contain the whole block table, otherwise will cause core-dump. self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx].index_select( 0, updated_table[beam_idx] ) self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx].index_select( 0, updated_table[beam_idx] ) free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) for i in free_table: if not (self.block_tables == i).any(): self.free_blocks[i] = 1 def crop(self, maximum_length: int): """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" max_seq_len = self._seen_tokens.max() if maximum_length < 0: maximum_length = max_seq_len - abs(maximum_length) if max_seq_len <= maximum_length: return origin_table = self.block_tables.clone() for bs in range(self._seen_tokens.shape[0]): new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len num_blocks = (new_tokens + self.block_size - 1) // self.block_size self.block_tables[bs, num_blocks:] = -1 self._seen_tokens[bs] = new_tokens free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) for i in free_table: if not (self.block_tables == i).any(): self.free_blocks[i] = 1