def _concatenate_to_cache()

in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """Write the new key/value projections into the autoregressive decoding cache.

        The cache lives in this module's ``"cache"`` variable collection
        (``cached_key``, ``cached_value``, ``cache_index``). When the cache has
        not been initialised yet, the variables are only declared and the inputs
        are returned unchanged. Otherwise the new vectors are scattered into the
        cache starting at the current cache index, the index is advanced, and the
        full cached key/value tensors are returned together with an attention
        mask that hides the cache slots that have not been filled yet.

        Args:
            key: new key projections, ``[batch_size, seq_length, num_heads, head_dim]``.
            value: new value projections, same layout as ``key``.
            query: query projections; ``query.shape[1]`` is taken as the number of
                positions being written to the cache in this call.
            attention_mask: mask to be combined with the cache-validity mask.

        Returns:
            A ``(key, value, attention_mask)`` tuple. With an initialised cache,
            ``key``/``value`` span the whole cache length.

        Raises:
            ValueError: if a single position is fed and ``query`` does not have
                shape ``(batch_size, 1, num_heads, head_dim)`` as implied by the cache.
        """
        # The following code is largely copied from: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
        is_initialized = self.has_variable("cache", "cached_key")

        # The key and value have dimension [batch_size, seq_length, num_heads, head_dim],
        # but we cache them as [batch_size, num_heads, head_dim, seq_length] as a TPU
        # fusion optimization. This also enables the "scatter via one-hot
        # broadcast" trick, which means we do a one-hot broadcast instead of a
        # scatter/gather operations, resulting in a 3-4x speedup in practice.
        def swap_dims(x):
            return x[:-3] + tuple(x[i] for i in [-2, -1, -3])

        cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            batch_size, num_heads, head_dim, seq_length = cached_key.value.shape
            # During fast autoregressive decoding, we usually feed one position at a
            # time, and cache the keys and values step by step.
            # Sanity shape check of cached key against input query.
            num_updated_cache_vectors = query.shape[1]
            expected_shape = (batch_size, 1, num_heads, head_dim)
            if num_updated_cache_vectors == 1 and expected_shape != query.shape:
                raise ValueError(
                    "Autoregressive cache shape error, expected query shape"
                    f" {expected_shape} instead got {query.shape}"
                )

            # First cache slot to write to. NOTE: the index is increased below.
            cur_index = cache_index.value

            # In order to update the key, value caches with the current key and
            # value, we move the seq_length axis to the back, similar to what we did for
            # the cached ones above. The trailing axis holds the position(s) fed in
            # this call (one, or several when a multi-token chunk is fed at once).
            new_key = jnp.moveaxis(key, -3, -1)
            new_value = jnp.moveaxis(value, -3, -1)

            # Update key, value caches with our new 1d spatial slices.
            # We implement an efficient scatter into the cache via one-hot
            # broadcast and addition.
            if num_updated_cache_vectors > 1:
                # Row i is the one-hot encoding of cache slot `cur_index + i`, so the
                # chunk lands right after what is already cached (for an empty cache
                # this equals a rectangular identity matrix). The matrix is built in
                # the key dtype, like the single-token branch, so the cache is not
                # silently promoted to a wider float type under half precision.
                positions = cur_index + jnp.arange(num_updated_cache_vectors)
                indices = jax.nn.one_hot(positions, seq_length, dtype=key.dtype)[None, None]
                key = cached_key.value + jnp.matmul(new_key, indices)
                value = cached_value.value + jnp.matmul(new_value, indices)
            else:
                one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype)
                key = cached_key.value + new_key * one_hot_indices
                value = cached_value.value + new_value * one_hot_indices

            cached_key.value = key
            cached_value.value = value
            cache_index.value = cache_index.value + num_updated_cache_vectors

            # Move the keys and values back to their original shapes.
            key = jnp.moveaxis(key, -1, -3)
            value = jnp.moveaxis(value, -1, -3)

            # mask for cached decoder self-attention: the query position(s) should only
            # attend to those key positions that have already been generated and cached, not the
            # remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(seq_length) < cur_index + num_updated_cache_vectors,
                (batch_size,) + (1, num_updated_cache_vectors, seq_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)

        return key, value, attention_mask