optimum/tpu/static_cache_xla.py
from typing import Any, Dict, Optional, Tuple

import torch
from transformers import StaticCache


class StaticCacheXla(StaticCache):
    """`StaticCache` variant for XLA: cache reads and writes go through
    `index_copy_`/`index_select` rather than advanced indexing, which avoids
    extra device copies on XLA (see the comments in `update` below).
    """

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know where to write in the cache.
Return:
A tuple containing the updated key and value states.
"""
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
        # `index_copy_(dim, index, source)` behaves like `tensor[:, :, index] = source`
        # (with dim=2), but it is more general and uses less memory on XLA.
# For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
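        # Example: with cache_position = torch.tensor([3, 4]) and key_states of
        # shape (batch, heads, 2, head_dim), the two new positions are written
        # into sequence slots 3 and 4 of k_out; all other slots are left intact.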
return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
key_cache = self.key_cache[layer_idx]
device = key_cache.device
        # `index_select(dim, index)` performs the same operation as `tensor[..., index, ...]`
        # (indexing along `dim`), but it is more general and uses less memory on XLA.
# For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
item = key_cache.index_select(0, torch.tensor(0, device=device))
head = item.index_select(1, torch.tensor(0, device=device))
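        # At this point `head` is equivalent to `key_cache[0:1, 0:1]`: the first
        # batch member and first head, with both indexed dims kept at size 1.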
return head.any(dim=-1).sum()
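

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the `StaticCache` constructor signature from transformers >= 4.38,
# (config, max_batch_size, max_cache_len, device, dtype), and uses a default
# `GPT2Config` purely to give the cache concrete shapes; on TPU you would pass
# an XLA device instead of "cpu".
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPT2Config

    config = GPT2Config()
    cache = StaticCacheXla(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads
    # Write one new token's key/value states into sequence slot 0 of layer 0.
    key_states = torch.ones(1, num_heads, 1, head_dim)
    value_states = torch.ones(1, num_heads, 1, head_dim)
    cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([0])})

    print(cache.get_seq_length(0))  # tensor(1): one occupied sequence slot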