optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py [1271:1331]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        hidden_states_q = hidden_states
        hidden_states_kv = hidden_states
        self.split_kv_b_proj()
        q_position_ids = position_ids
        kv_position_ids = position_ids
        bsz, q_len, _ = hidden_states_q.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states_q)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states_q)))

        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)

        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        kv_seq_len = q_pe.shape[-2]

        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            if token_idx is None:
                if hasattr(past_key_value, "get_usable_length"):
                    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
                else:
                    kv_seq_len += past_key_value[0].shape[-2]
            else:
                if reuse_cache:
                    # With reuse_cache, past_key_value holds the cached shapes (see below), not tensors.
                    kv_seq_len = past_key_value[0][-2]
                else:
                    kv_seq_len = past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
        q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)
        # Absorb the key half of kv_b_proj (q_absorb, split out by split_kv_b_proj) into the
        # no-RoPE query so attention can be computed against the compressed latent KV directly.
        q_nope = torch.matmul(q_nope.transpose(0, 1), self.q_absorb).transpose(0, 1)
        # Project the hidden states down to the compressed latent KV and the decoupled RoPE key.
        compressed_kv, k_pe = self.compress_kv(hidden_states_kv, kv_position_ids)

        # Update the caches and retrieve the full compressed_kv / k_pe history.
        if use_cache:
            if reuse_cache:
                if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor):
                    # Prefix-tuning case: prepend the supplied past_key_value before generating the first token.
                    compressed_kv = torch.cat((past_key_value[0], compressed_kv), -2)
                    k_pe = torch.cat((past_key_value[1], k_pe), -2)

                # Write the current step into the module-owned HPU caches; only their shapes
                # are handed back to the caller as past_key_value.
                compressed_kv = self.k_cache(compressed_kv, 1, token_idx)
                k_pe = self.v_cache(k_pe, 1, token_idx)
                past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())

            else:
                if past_key_value is None:
                    # First step without reuse_cache: seed the cache with zero-filled tensors
                    # shaped like the current compressed_kv / k_pe.
                    dtype_1 = hidden_states.dtype
                    device_1 = hidden_states.device
                    past_key = torch.zeros(compressed_kv.shape, dtype=dtype_1, device=device_1)
                    past_value = torch.zeros(k_pe.shape, dtype=dtype_1, device=device_1)
                    past_key_value = (past_key, past_value)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
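
The q_absorb matmul above is the weight-absorption step of multi-head latent attention: the key half of kv_b_proj is folded into the no-RoPE query so attention scores can be taken against the cached compressed latent rather than fully expanded per-head keys. Below is a minimal sketch of the idea with illustrative shapes and names; the decoupled RoPE branch (q_pe / k_pe) is omitted and this is not the module's actual tensor layout.

import torch

num_heads, nope_dim, v_head_dim, kv_lora_rank = 4, 64, 64, 128
bsz, q_len, kv_len = 2, 3, 5

# kv_b_proj maps the rank-kv_lora_rank latent to per-head [k_nope; v] features;
# split its weight into a key part ("q_absorb") and a value part ("out_absorb").
kv_b_weight = torch.randn(num_heads * (nope_dim + v_head_dim), kv_lora_rank)
w = kv_b_weight.view(num_heads, nope_dim + v_head_dim, kv_lora_rank)
q_absorb = w[:, :nope_dim, :]      # (heads, nope_dim, kv_lora_rank)
out_absorb = w[:, nope_dim:, :]    # (heads, v_head_dim, kv_lora_rank)

q_nope = torch.randn(bsz, num_heads, q_len, nope_dim)
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)   # cached latent KV

# Absorb the key up-projection into the query ...
q_lat = torch.matmul(q_nope, q_absorb)                   # (bsz, heads, q_len, rank)
# ... so scores are taken directly against the compressed latent.
scores = torch.matmul(q_lat, compressed_kv.unsqueeze(1).transpose(-1, -2))
attn = torch.softmax(scores / nope_dim**0.5, dim=-1)
# The value up-projection is applied only after attending over the latent.
ctx = torch.matmul(torch.matmul(attn, compressed_kv.unsqueeze(1)), out_absorb.transpose(-1, -2))
print(ctx.shape)  # torch.Size([2, 4, 3, 64])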



optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py [1067:1127]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            hidden_states_q = hidden_states
            hidden_states_kv = hidden_states
            self.split_kv_b_proj()
            q_position_ids = position_ids
            kv_position_ids = position_ids
            bsz, q_len, _ = hidden_states_q.size()

            if self.q_lora_rank is None:
                q = self.q_proj(hidden_states_q)
            else:
                q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states_q)))

            q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)

            q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

            kv_seq_len = q_pe.shape[-2]

            if past_key_value is not None:
                if self.layer_idx is None:
                    raise ValueError(
                        f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                        "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                        "with a layer index."
                    )
                if token_idx is None:
                    if hasattr(past_key_value, "get_usable_length"):
                        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
                    else:
                        kv_seq_len += past_key_value[0].shape[-2]
                else:
                    if reuse_cache:
                        kv_seq_len = past_key_value[0][-2]
                    else:
                        kv_seq_len = past_key_value[0].shape[-2]

            cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
            q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)
            q_nope = torch.matmul(q_nope.transpose(0, 1), self.q_absorb).transpose(0, 1)
            compressed_kv, k_pe = self.compress_kv(hidden_states_kv, kv_position_ids)

            # Update the caches and retrieve the full compressed_kv / k_pe history.
            if use_cache:
                if reuse_cache:
                    if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor):
                        # Prefix-tuning case: prepend the supplied past_key_value before generating the first token.
                        compressed_kv = torch.cat((past_key_value[0], compressed_kv), -2)
                        k_pe = torch.cat((past_key_value[1], k_pe), -2)

                    compressed_kv = self.k_cache(compressed_kv, 1, token_idx)
                    k_pe = self.v_cache(k_pe, 1, token_idx)
                    past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())

                else:
                    if past_key_value is None:
                        dtype_1 = hidden_states.dtype
                        device_1 = hidden_states.device
                        past_key = torch.zeros(compressed_kv.shape, dtype=dtype_1, device=device_1)
                        past_value = torch.zeros(k_pe.shape, dtype=dtype_1, device=device_1)
                        past_key_value = (past_key, past_value)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
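
In the reuse_cache branch above, the latent and RoPE-key caches are buffers owned by the attention module: they are written in place at token_idx and only their shapes are returned as past_key_value. The class below is a hypothetical stand-in with the same (cur, dim, idx) call pattern, meant only to show the static-shape idea; it is not optimum-habana's actual cache class, and the real prefill / token_idx semantics may differ.

import torch


class StaticLatentCache(torch.nn.Module):
    """Hypothetical stand-in for the module-owned caches above: the buffer is
    preallocated to a fixed maximum length and written in place, so tensor
    shapes stay static across decode steps."""

    def __init__(self, max_len):
        super().__init__()
        self.max_len = max_len
        self.cache = None

    def forward(self, cur, dim, idx):
        if self.cache is None:
            # Prefill: allocate the full-length buffer and copy the prompt states in.
            shape = list(cur.shape)
            shape[dim] = self.max_len
            self.cache = torch.zeros(shape, dtype=cur.dtype, device=cur.device)
            idx = torch.arange(cur.shape[dim], device=cur.device)
        # Write cur in place along dim at idx; the buffer's shape never changes.
        self.cache.index_copy_(dim, idx, cur)
        return self.cache

    def get_shape(self):
        return None if self.cache is None else self.cache.shape


# Usage sketch: bsz=1, kv_lora_rank=8, a 4-token prompt, then one decode step at position 4.
cache = StaticLatentCache(max_len=16)
cache(torch.randn(1, 4, 8), dim=1, idx=None)               # prefill
cache(torch.randn(1, 1, 8), dim=1, idx=torch.tensor([4]))  # decode step
print(cache.get_shape())                                   # torch.Size([1, 16, 8])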



