def _forward_packed()

in src/nanotron/models/qwen.py


    def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens):
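        """Attention over packed (varlen) sequences: documents are concatenated along the
        token axis, and per-document boundaries are described by cu_seqlens (an int32
        tensor for the standard flash-attn path, or a dict of ring-attention metadata for
        the llama3_ring_attention path)."""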
        assert cu_seqlens is not None, "cu_seqlens must be provided for packed attention"
        # Split the fused QKV projection into query and packed key/value tensors, then
        # unflatten to (batch, seq_length, n_heads, head_dim) for the rotary embedding.
        q = qkv[..., : self.local_num_heads * self.head_dim]  # Not contiguous, similar to flash_attn
        kv = qkv[..., self.local_num_heads * self.head_dim :]  # Not contiguous, similar to flash_attn
        q = q.view(-1, seq_length, self.local_num_heads, self.head_dim)
        kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim)
        # Apply RoPE unless this is a "no-rope" layer (every no_rope_layer-th layer skips it)
        if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0:
            # With context parallelism, each rank of cp_pg holds a contiguous chunk of the full
            # sequence, so rotary positions are offset by rank * local seq_length.
            seqlen_offset = dist.get_rank(self.cp_pg) * seq_length
            q, kv = self.rotary_emb(
                q, kv, seqlen_offset=seqlen_offset, max_seqlen=seq_length * self.cp_pg_size
            )
        else:
            log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0)
            self.sliding_window_size = None  # WARNING: we skip sliding window for no-rope layers

        # Flatten back to (total_tokens, n_heads, head_dim), the layout expected by the varlen kernels
        q = q.view(-1, self.local_num_heads, self.head_dim)
        kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim)
        max_seqlen = seq_length  # TODO: should this be max position_ids?

        if self.config._attn_implementation == "llama3_ring_attention":
            attn_output = llama3_flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens_q=cu_seqlens["cu_seqlens_q"],
                cu_seqlens_k=cu_seqlens["cu_seqlens_k"],
                max_seqlen_q=cu_seqlens["max_seqlen_q"],
                max_seqlen_k=cu_seqlens["max_seqlen_k"],
                heads_k_stride=self.heads_k_stride,
                local_k_slice=cu_seqlens["local_k_slice"],
                dropout_p=0.0,
                softmax_scale=None,
                causal=True,
                alibi_slopes=None,
                window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1),
                deterministic=False,
                return_attn_probs=self.log_attn_probs,
                group=self.cp_pg,
            )  # Not contiguous, similar to flash_attn
        else:
            # Standard FlashAttention varlen path: causal attention within each packed sequence
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            attn_output = flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                softmax_scale=None,
                causal=True,
                alibi_slopes=None,
                window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1),
                deterministic=False,
                return_attn_probs=self.log_attn_probs,
            )  # Not contiguous, similar to flash_attn

        if self.log_attn_probs:
            attn_output, attn_probs, _ = attn_output
            # log attn_probs
            self.tbi_logger({"attn_probs": attn_probs})
        # flash_attn use rearrange instead of reshape https://github.com/Dao-AILab/flash-attention/blob/1a58058a6da83bd7baaf4c512e8a1abe0240bb77/flash_attn/modules/mha.py#L730
        return attn_output.reshape(-1, self.local_num_heads * self.head_dim)  # [b*s, num_heads*head_dim]
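For reference, a minimal sketch of how the inputs consumed by the standard (non-ring) branch are typically built when several documents are packed back-to-back into one batch. This is illustrative only and not code from qwen.py; document_lengths is a hypothetical name.

    # Illustrative sketch only -- not nanotron code. Shows the shapes/dtypes that the
    # flash_attn_varlen_* call above expects for a packed batch of documents.
    from itertools import accumulate

    import torch

    document_lengths = [5, 3, 8]  # hypothetical tokens per packed document

    # cu_seqlens: cumulative sequence lengths with a leading 0, int32 (checked by the assert above)
    cu_seqlens = torch.tensor(list(accumulate(document_lengths, initial=0)), dtype=torch.int32)
    # tensor([ 0,  5,  8, 16], dtype=torch.int32)

    # max_seqlen: a plain Python int (also checked by the asserts above)
    max_seqlen = max(document_lengths)  # 8

    # position_ids typically restart at 0 at each document boundary in packed batches
    position_ids = torch.cat([torch.arange(n) for n in document_lengths])
    # tensor([0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7])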