def add_to_kv_cache()

in optimum/graphcore/generation/attention_mixin.py


    def add_to_kv_cache(self, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Copies the key-value pair into their corresponding key-value caches.

        Args:
            key (`torch.FloatTensor`): key tensor of shape `(batch_size * num_beams, num_heads, 1, head_dim)`.
            value (`torch.FloatTensor`): value tensor of shape `(batch_size * num_beams, num_heads, 1, head_dim)`.
        """
        if not self.kv_cache_initialized:
            raise ValueError(
                f"{self.__class__.__name__} assumes that self-attention has KV caching enabled. "
                f"Please instantiate using `{self.__class__.__name__}.from_model()` so the KV "
                "cache can be created."
            )

        if self.training:
            raise RuntimeError("KV caching is currently only supported for inference.")

        expected_key_shape, expected_value_shape = list(self._k_cache.shape), list(self._v_cache.shape)
        expected_key_shape[-2] = 1
        expected_value_shape[-2] = 1
        if list(key.shape) != expected_key_shape:
            raise ValueError(f"Expected key shape {expected_key_shape}, received {list(key.shape)}.")
        if list(value.shape) != expected_value_shape:
            raise ValueError(f"Expected value shape {expected_value_shape}, received {list(value.shape)}.")

        # For now assume that generation will always start from step 0.
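        # Zeroing via multiplication, rather than a data-dependent branch, resets the caches at step 0.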
        reset_kv_cache = self._generation_step == 0
        self._k_cache *= 1 - reset_kv_cache.to(self._k_cache.dtype)
        self._v_cache *= 1 - reset_kv_cache.to(self._v_cache.dtype)

        if hasattr(self, "_beam_idx"):
            # For beam search, permute the cache since inputs are permuted on host.
            _k_cache = torch.index_select(self._k_cache, 0, self._beam_idx)
            _v_cache = torch.index_select(self._v_cache, 0, self._beam_idx)
            self._k_cache.copy_(_k_cache)
            self._v_cache.copy_(_v_cache)

        # Dynamic update leads to uneven tile placement, and scatter leads to large re-arrangements,
        # so use a brute force matmul approach which empirically seems best for now.
        bsz, heads, src_len, head_dim = self._k_cache.shape
        mm_mask = (torch.arange(src_len) == self._generation_step).view(src_len, 1)
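        # mm_mask is a one-hot (src_len, 1) column for the current step; the matmul broadcasts it
        # against each (1, head_dim) key/value, so the new entry lands in row `_generation_step`
        # of the cache while every other row stays zero.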
        _key = torch.matmul(mm_mask.to(key.dtype), key.view(bsz * heads, 1, head_dim))
        _value = torch.matmul(mm_mask.to(value.dtype), value.view(bsz * heads, 1, head_dim))
        self._k_cache += _key.view(self._k_cache.shape)
        self._v_cache += _value.view(self._v_cache.shape)

        return self._k_cache, self._v_cache
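
For reference, the masked-matmul write can be reproduced in isolation. The sketch below is a minimal illustration of the update rule using small, hypothetical shapes and plain tensors standing in for `self._k_cache` and `self._generation_step`; it is not part of the library API.

    import torch

    bsz, heads, src_len, head_dim = 2, 1, 4, 3    # hypothetical toy shapes
    k_cache = torch.zeros(bsz, heads, src_len, head_dim)
    generation_step = 2                           # hypothetical current decode step
    key = torch.randn(bsz, heads, 1, head_dim)    # new key for this step

    # One-hot column selecting the cache row to write.
    mm_mask = (torch.arange(src_len) == generation_step).view(src_len, 1)

    # (src_len, 1) @ (bsz * heads, 1, head_dim) broadcasts to (bsz * heads, src_len, head_dim):
    # row `generation_step` holds the new key, every other row is zero.
    update = torch.matmul(mm_mask.to(key.dtype), key.view(bsz * heads, 1, head_dim))
    k_cache += update.view(k_cache.shape)

    assert torch.equal(k_cache[:, :, generation_step], key.squeeze(2))

Adding this one-hot product into the cache writes exactly one row per step, which is why the in-code comment prefers it over dynamic updates or scatter.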