in src/nanotron/models/qwen.py [0:0]
def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens):
    assert cu_seqlens is not None, "cu_seqlens must be provided for packed attention"
    q = qkv[..., : self.local_num_heads * self.head_dim]  # Not contiguous, similar to flash_attn
    kv = qkv[..., self.local_num_heads * self.head_dim :]  # Not contiguous, similar to flash_attn
    q = q.view(-1, seq_length, self.local_num_heads, self.head_dim)
    kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim)
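    # Packed ("varlen") layout: the sequences of the batch are concatenated along the token
    # dimension. For the plain flash_attn path, cu_seqlens is an int32 tensor of cumulative
    # sequence lengths delimiting them, e.g. lengths (3, 4, 5) -> tensor([0, 3, 7, 12]);
    # for the llama3 ring-attention path it is a dict of per-rank metadata (see below).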
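    # q: [batch, seq_length, local_num_heads, head_dim], kv: [batch, seq_length, 2, local_num_kv_heads, head_dim]
    # (K and V stacked along dim 2); this is the layout self.rotary_emb expects.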
    if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0:
        seqlen_offset = dist.get_rank(self.cp_pg) * seq_length
        q, kv = self.rotary_emb(
            q, kv, seqlen_offset=seqlen_offset, max_seqlen=seq_length * self.cp_pg_size
        )
    else:
        log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0)
        self.sliding_window_size = None  # WARNING: we skip sliding window for no-rope layers
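        # Each context-parallel rank holds its own seq_length-token slice of the full sequence,
        # so RoPE positions are shifted by rank * seq_length and max_seqlen covers the
        # unsharded length seq_length * cp_pg_size.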
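        # With sliding_window_size unset, window_size below resolves to (-1, -1),
        # i.e. this no-rope layer uses full (global) causal attention.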
    q = q.view(-1, self.local_num_heads, self.head_dim)
    kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim)
    max_seqlen = seq_length  # TODO: should this be the longest packed sequence (max of position_ids) instead?
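    # flatten back to the packed layout expected by the varlen kernels:
    # q: [total_tokens, local_num_heads, head_dim], kv: [total_tokens, 2, local_num_kv_heads, head_dim]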
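    # window_size follows the flash_attn (left, right) convention: (w - 1, 0) lets each query
    # attend to itself and the previous w - 1 tokens, while (-1, -1) disables the window.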
    if self.config._attn_implementation == "llama3_ring_attention":
        attn_output = llama3_flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_seqlens_q=cu_seqlens["cu_seqlens_q"],
            cu_seqlens_k=cu_seqlens["cu_seqlens_k"],
            max_seqlen_q=cu_seqlens["max_seqlen_q"],
            max_seqlen_k=cu_seqlens["max_seqlen_k"],
            heads_k_stride=self.heads_k_stride,
            local_k_slice=cu_seqlens["local_k_slice"],
            dropout_p=0.0,
            softmax_scale=None,
            causal=True,
            alibi_slopes=None,
            window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1),
            deterministic=False,
            return_attn_probs=self.log_attn_probs,
            group=self.cp_pg,
        )  # Not contiguous, similar to flash_attn
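        # Context-parallel path: ring_flash_attn's "llama3" variant all-gathers K/V across
        # self.cp_pg (heads_k_stride key/value heads at a time), and local_k_slice selects
        # the slice of the gathered keys needed for this rank's queries.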
    else:
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        attn_output = flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            0.0,
            softmax_scale=None,
            causal=True,
            alibi_slopes=None,
            window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1),
            deterministic=False,
            return_attn_probs=self.log_attn_probs,
        )  # Not contiguous, similar to flash_attn
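        # Single-device attention path: FlashAttention's varlen kernel over the packed batch.
        # Positional args are cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k and
        # dropout_p; q and k share the same metadata since this is self-attention.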
    if self.log_attn_probs:
        attn_output, attn_probs, _ = attn_output
        # log the attention probabilities returned by the kernel
        self.tbi_logger({"attn_probs": attn_probs})
    # flash_attn uses rearrange instead of reshape, see https://github.com/Dao-AILab/flash-attention/blob/1a58058a6da83bd7baaf4c512e8a1abe0240bb77/flash_attn/modules/mha.py#L730
    return attn_output.reshape(-1, self.local_num_heads * self.head_dim)  # [b*s, num_heads*head_dim]
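
# Minimal usage sketch (illustrative only): `attn` stands in for the attention module this
# method belongs to, and `fused_qkv` for the output of its fused QKV projection; both names
# are placeholders, not necessarily the module's actual attributes.
#
#     cu_seqlens = torch.tensor([0, 3, 7, 12], dtype=torch.int32, device=fused_qkv.device)  # 3 packed sequences
#     out = attn._forward_packed(fused_qkv, seq_length=12, position_ids=None, cu_seqlens=cu_seqlens)
#     # out: [total_tokens, num_heads * head_dim]; position_ids is not used on the code path shown above.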