in optimum/graphcore/generation/attention_mixin.py
def add_to_kv_cache(self, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Copies the key-value pair for the current generation step into the corresponding key-value caches.

    Args:
        key (`torch.FloatTensor`): key tensor of shape `(batch_size * num_beams, num_heads, 1, head_dim)`.
        value (`torch.FloatTensor`): value tensor of shape `(batch_size * num_beams, num_heads, 1, head_dim)`.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: the full key and value caches after the update.
    """
    if not self.kv_cache_initialized:
        raise ValueError(
            f"{self.__class__.__name__} assumes that self-attention has KV caching enabled. "
            f"Please instantiate using `{self.__class__.__name__}.from_model()` so the KV "
            "cache can be created."
        )
    if self.training:
        raise RuntimeError("KV caching is currently only supported for inference.")
    expected_key_shape, expected_value_shape = list(self._k_cache.shape), list(self._v_cache.shape)
    expected_key_shape[-2] = 1
    expected_value_shape[-2] = 1
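    # Each step carries exactly one new token, so the inputs must match the cache shape
    # everywhere except the sequence dimension: (batch_size * num_beams, num_heads, 1, head_dim)
    # against a cache whose sequence dimension is the full generation length.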
    if list(key.shape) != expected_key_shape:
        raise ValueError(f"Expected key shape {expected_key_shape}, received {list(key.shape)}.")
    if list(value.shape) != expected_value_shape:
        raise ValueError(f"Expected value shape {expected_value_shape}, received {list(value.shape)}.")
    # For now assume that generation will always start from step 0.
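    # On step 0 `reset_kv_cache` is 1, so multiplying by `1 - reset_kv_cache` clears the
    # caches; on later steps it is 0 and the caches pass through unchanged. This branchless
    # reset presumably avoids data-dependent control flow in the statically compiled graph.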
    reset_kv_cache = self._generation_step == 0
    self._k_cache *= 1 - reset_kv_cache.to(self._k_cache.dtype)
    self._v_cache *= 1 - reset_kv_cache.to(self._v_cache.dtype)
    if hasattr(self, "_beam_idx"):
        # For beam search, permute the cache since inputs are permuted on host.
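        # `_beam_idx` is assumed to hold, for each batch-beam row, the index of the parent
        # beam selected on the host at the previous step; reordering the caches with it keeps
        # them aligned with the reordered inputs.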
        _k_cache = torch.index_select(self._k_cache, 0, self._beam_idx)
        _v_cache = torch.index_select(self._v_cache, 0, self._beam_idx)
        self._k_cache.copy_(_k_cache)
        self._v_cache.copy_(_v_cache)
    # Dynamic update leads to uneven tile placement, and scatter leads to large re-arrangements,
    # so use a brute-force matmul approach which empirically seems best for now.
    bsz, heads, src_len, head_dim = self._k_cache.shape
    # One-hot column vector with a 1 in the row for the current generation step.
    mm_mask = (torch.arange(src_len) == self._generation_step).view(src_len, 1)
    # (src_len, 1) @ (bsz * heads, 1, head_dim) broadcasts to (bsz * heads, src_len, head_dim),
    # placing the new entry in row `_generation_step` and zeros everywhere else, so the
    # additions below write it into the caches without a scatter.
    _key = torch.matmul(mm_mask.to(key.dtype), key.view(bsz * heads, 1, head_dim))
    _value = torch.matmul(mm_mask.to(value.dtype), value.view(bsz * heads, 1, head_dim))
    self._k_cache += _key.view(self._k_cache.shape)
    self._v_cache += _value.view(self._v_cache.shape)
    return self._k_cache, self._v_cache
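As a sanity check of the one-hot matmul trick, the following standalone sketch (hypothetical shapes and names, independent of the mixin) shows that the masked matmul is equivalent to a plain indexed write into the cache:

import torch

bsz, heads, src_len, head_dim = 2, 4, 16, 8
step = 3  # current generation step
cache = torch.zeros(bsz, heads, src_len, head_dim)
key = torch.randn(bsz, heads, 1, head_dim)  # one new token per step

# One-hot column vector: 1 in row `step`, 0 elsewhere.
mm_mask = (torch.arange(src_len) == step).view(src_len, 1)
# (src_len, 1) @ (bsz * heads, 1, head_dim) -> (bsz * heads, src_len, head_dim)
update = torch.matmul(mm_mask.to(key.dtype), key.view(bsz * heads, 1, head_dim))
cache += update.view(cache.shape)

# Reference: the dynamic-indexing write the IPU implementation avoids.
reference = torch.zeros(bsz, heads, src_len, head_dim)
reference[:, :, step] = key.squeeze(2)
assert torch.equal(cache, reference)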