# src/nanotron/models/llama.py
def _forward_inference(self, query_states, key_states, value_states, sequence_mask, batch_size, q_length, store):
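    """Single inference step with a flash-attn KV cache.

    On the first call (prefill, no "key" in `store`) the full prompt is attended to with
    `flash_attn_varlen_func` and the K/V caches are preallocated and filled. On subsequent
    calls (decode, q_length == 1) the cached K/V are reused via `flash_attn_with_kvcache`.
    Returns the projected attention output together with the (unchanged) sequence mask.
    """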
    # flash-attn pieces used below; `torch` and the `pad_to_right` helper come from the enclosing module
    from flash_attn import bert_padding
    from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
    assert key_states.requires_grad is False
    assert value_states.requires_grad is False
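
    # Recover per-token positions from the (left-padded) mask: on decode steps we continue
    # from the stored offsets; on prefill, cumsum over the mask gives 0..n-1 for real tokens
    # and -1 for left-pad positions. `position_offsets` holds each sample's last token position.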
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
position_offsets = position_ids[:, -1]
    # Compute rotary embeddings
    # Note: keep track of the old rotary embedding end to check whether k_cache and v_cache need to be enlarged
    old_rotary_embed_end = self.rotary_embedding.end
    if self.rope_interleaved:
        # interleaved version
        query_states = self.rotary_embedding(query_states, position_ids=position_ids)
        key_states = self.rotary_embedding(key_states, position_ids=position_ids)
    else:
        # non-interleaved version
        cos, sin = self.rotary_embedding(value_states, position_ids)
        query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)
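
    # `store` is empty on the first call (prefill) and afterwards holds the K/V cache and
    # position offsets, so its contents select between the prefill and decode paths below.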
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
        # Preallocate k_cache and v_cache to self.prefill_kv_len
        k_cache = torch.zeros(
            (
                batch_size,
                self.prefill_kv_len,
                self.n_local_kv_heads,
                self.d_qk,
            ),
            dtype=query_states.dtype,
            device=query_states.device,
        )
        v_cache = torch.zeros(
            (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
            dtype=query_states.dtype,
            device=query_states.device,
        )
        # Remove pad tokens from the query/key/value states and concatenate all samples (e.g. key_unpad);
        # cu_seqlens_k is the cumulative sequence lengths of key_states
        (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
            query_states,
            sequence_mask,
        )
        (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key_states, sequence_mask)
        (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
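
        # The packed tensors plus cu_seqlens let flash-attn treat the whole batch as one long
        # run of valid tokens, computing causal attention per sample without ever
        # materializing the padding positions.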
        # NOTE: this softmax scale (1/d_qk) is for µTransfer;
        # under standard parametrization (SP) we use the flash-attn default 1/sqrt(d_qk) (softmax_scale=None)
        softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
        output_unpad = flash_attn_varlen_func(
            q=query_unpad,  # (total_q, n_local_q_heads, d_qk)
            k=key_unpad,  # (total_kv, n_local_kv_heads, d_qk)
            v=value_unpad,  # (total_kv, n_local_kv_heads, d_v)
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
            causal=True,  # True in prefill phase, False in subsequent phases
            return_attn_probs=False,
        )  # (total_unpadded, n_local_q_heads, d_v)
        attention_output = bert_padding.pad_input(
            output_unpad, indices_q, batch_size, q_length
        )  # (batch_size, q_length, n_local_q_heads, d_v)

        pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
        pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
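        # pad_to_right (helper defined elsewhere in the file) appears to shift each sample's valid tokens
        # to the front and write them into the preallocated caches, so cached K/V start at index 0
        # regardless of left padding.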
    else:
        # Subsequent inference iterations (decode, q_length == 1): pull the pre-computed K/V cache
        k_cache = store["key"]
        v_cache = store["value"]

        # NOTE(fmom): according to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
        # If the rotary embedding table has grown (to enable a larger context), enlarge k_cache and v_cache accordingly
        if self.rotary_embedding.end > old_rotary_embed_end:
            k_cache = torch.cat(
                [
                    k_cache,
                    torch.zeros(
                        (
                            batch_size,
                            self.rotary_embedding.end - old_rotary_embed_end,
                            self.n_local_kv_heads,
                            self.d_qk,
                        ),
                        dtype=query_states.dtype,
                        device=query_states.device,
                    ),
                ],
                dim=1,
            )
            v_cache = torch.cat(
                [
                    v_cache,
                    torch.zeros(
                        (
                            batch_size,
                            self.rotary_embedding.end - old_rotary_embed_end,
                            self.n_local_kv_heads,
                            self.d_v,
                        ),
                        dtype=query_states.dtype,
                        device=query_states.device,
                    ),
                ],
                dim=1,
            )
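
        # Zeros were appended along dim=1 (the sequence axis), so the cache length should now
        # match the new rotary_embedding.end; the asserts below check exactly that.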
assert (
k_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
assert (
v_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
        # [batch_size, seq_length, num_heads, d_qk]
        query_states = query_states.view(
            batch_size, q_length, self.n_local_q_heads, self.d_qk
        )  # [batch_size, q_length, self.n_heads, d_qk]
        kv_length = key_states.shape[1]
        key_states = key_states.view(
            batch_size, kv_length, self.n_local_kv_heads, self.d_qk
        )  # [batch_size, kv_length, self.n_heads, d_qk]
        value_states = value_states.view(
            batch_size, kv_length, self.n_local_kv_heads, self.d_v
        )  # [batch_size, kv_length, self.n_heads, d_v]
        # NOTE: this softmax scale (1/d_qk) is for µTransfer;
        # under standard parametrization (SP) we use the flash-attn default 1/sqrt(d_qk) (softmax_scale=None)
        softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
        attention_output = flash_attn_with_kvcache(
            query_states,
            k_cache,
            v_cache,
            key_states,
            value_states,
            rotary_cos=None,
            rotary_sin=None,
            # TODO @nouamane: this doesn't seem to help indicate padding (for the first iteration it's just 0)
            cache_seqlens=position_offsets.contiguous(),
            softmax_scale=softmax_scale,
            causal=True,
            rotary_interleaved=False,  # not used unless rotary_cos/sin is provided, see https://github.com/Dao-AILab/flash-attention
        )
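        # flash_attn_with_kvcache writes key_states/value_states into k_cache/v_cache in place at the
        # positions given by cache_seqlens, then attends over the cache plus the new token(s).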

    store.update(
        {
            "key": k_cache,  # flash-attn has updated with new key_states using cache_seqlens
            "value": v_cache,
            "position_offsets": position_offsets,
        }
    )
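    # The store persists across forward calls, so the next decode step finds "key", "value"
    # and "position_offsets" and takes the decode branch above.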
    attention_output = (
        attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
    )
    output = self.o_proj(attention_output)

    return {"hidden_states": output, "sequence_mask": sequence_mask}