in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def _concatenate_to_cache(self, key, value, query, attention_mask):
    """Write freshly projected keys/values into the decoding cache.

    The cache lives in the ``"cache"`` variable collection as ``cached_key``,
    ``cached_value`` and the scalar ``cache_index``. The three variables are
    declared on every call; the cache is only read and updated when
    ``cached_key`` already existed before this call. Otherwise the inputs are
    returned untouched.

    Args:
        key: New key projections. Per the layout note below, expected as
            ``[batch_size, seq_length, num_heads, head_dim]``.
        value: New value projections, same layout as ``key``.
        query: Query projections. Only ``query.shape`` is read:
            ``query.shape[1]`` is the number of positions written by this call.
        attention_mask: Mask that is combined (via ``combine_masks``) with the
            mask hiding the not-yet-filled cache slots.

    Returns:
        A ``(key, value, attention_mask)`` tuple. When the cache is
        initialized, ``key`` and ``value`` are the full updated caches moved
        back to the input layout, and ``attention_mask`` additionally hides
        every slot at or beyond ``cache_index + query.shape[1]``.

    Raises:
        ValueError: If a single position is fed and ``query.shape`` is not
            ``(batch_size, 1, num_heads, head_dim)`` as implied by the cache.
    """
    # The following code is largely copied from: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
    # Must be queried BEFORE the variables are declared below, otherwise it
    # would always be True.
    is_initialized = self.has_variable("cache", "cached_key")

    # The key and value have dimension [batch_size, seq_length, num_heads, head_dim],
    # but we cache them as [batch_size, num_heads, head_dim, seq_length] as a TPU
    # fusion optimization. This also enables the "scatter via one-hot
    # broadcast" trick, which means we do a one-hot broadcast instead of a
    # scatter/gather operations, resulting in a 3-4x speedup in practice.
    def swap_dims(x):
        return x[:-3] + tuple(x[i] for i in [-2, -1, -3])

    cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
    cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
    cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

    if is_initialized:
        batch_size, num_heads, head_dim, seq_length = cached_key.value.shape
        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        # Sanity shape check of cached key against input query.
        num_updated_cache_vectors = query.shape[1]
        expected_shape = (batch_size, 1, num_heads, head_dim)
        if num_updated_cache_vectors == 1 and expected_shape != query.shape:
            raise ValueError(
                "Autoregressive cache shape error, expected query shape"
                f" {expected_shape} instead got {query.shape}"
            )

        # First cache slot written by this call. NOTE: the index is increased below.
        cur_index = cache_index.value

        # In order to update the key, value caches with the current key and
        # value, we move the seq_length axis to the back, similar to what we did for
        # the cached ones above.
        one_token_key = jnp.moveaxis(key, -3, -1)
        one_token_value = jnp.moveaxis(value, -3, -1)

        # Update key, value caches with our new spatial slices.
        # We implement an efficient scatter into the cache via one-hot
        # broadcast and addition.
        if num_updated_cache_vectors > 1:
            # Scatter matrix of shape [num_updated_cache_vectors, seq_length]:
            # row t is one-hot at slot `cur_index + t`, so new position t lands
            # right after what is already cached. For an empty cache
            # (cur_index == 0) this is exactly
            # jnp.eye(num_updated_cache_vectors, seq_length).
            # It is cast to key.dtype (as the single-position branch does) so
            # that a half-precision cache is not promoted to the default float
            # dtype by the matmul/addition below.
            new_positions = jnp.arange(num_updated_cache_vectors)[:, None]
            cache_slots = jnp.arange(seq_length)[None, :]
            indices = (new_positions + cur_index == cache_slots).astype(key.dtype)[None, None]
            key = cached_key.value + jnp.matmul(one_token_key, indices)
            value = cached_value.value + jnp.matmul(one_token_value, indices)
        else:
            # Single position: broadcasting against a one-hot vector over the
            # seq_length axis places it at slot `cur_index`.
            one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype)
            key = cached_key.value + one_token_key * one_hot_indices
            value = cached_value.value + one_token_value * one_hot_indices

        cached_key.value = key
        cached_value.value = value
        cache_index.value = cache_index.value + num_updated_cache_vectors

        # Move the keys and values back to their original shapes.
        key = jnp.moveaxis(key, -1, -3)
        value = jnp.moveaxis(value, -1, -3)

        # mask for cached decoder self-attention: the query positions should only
        # attend to those key positions that have already been generated and cached, not the
        # remaining zero elements.
        pad_mask = jnp.broadcast_to(
            jnp.arange(seq_length) < cur_index + num_updated_cache_vectors,
            (batch_size,) + (1, num_updated_cache_vectors, seq_length),
        )
        attention_mask = combine_masks(pad_mask, attention_mask)

    return key, value, attention_mask