optimum/habana/transformers/models/persimmon/modeling_persimmon.py [68:119]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0:
                # When token_idx is used, static seq len = (input token len + max output token len)
                kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_ndims],
            query_states[..., self.rotary_ndims :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_ndims],
            key_states[..., self.rotary_ndims :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos[position_ids], sin[position_ids])

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            if token_idx is not None:
                if 0 <= self.layer_idx < len(past_key_value.key_cache):
                    past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
                    past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
                    key_states = past_key_value.key_cache[self.layer_idx]
                    value_states = past_key_value.value_cache[self.layer_idx]
                else:
                    past_key_value.key_cache.append(key_states)
                    past_key_value.value_cache.append(value_states)
            else:
                # Specific to RoPE models with partial rotation
                cache_kwargs = {
                    "sin": sin,
                    "cos": cos,
                    "partial_rotation_size": self.rotary_ndims,
                    "cache_position": cache_position,
                }
                key_states, value_states = past_key_value.update(
                    key_states, value_states, self.layer_idx, cache_kwargs
                )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



optimum/habana/transformers/models/stablelm/modeling_stablelm.py [66:117]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0:
                # When token_idx is used, static seq len = (input token len + max output token len)
                kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_ndims],
            query_states[..., self.rotary_ndims :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_ndims],
            key_states[..., self.rotary_ndims :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos[position_ids], sin[position_ids])

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            if token_idx is not None:
                if 0 <= self.layer_idx < len(past_key_value.key_cache):
                    past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
                    past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
                    key_states = past_key_value.key_cache[self.layer_idx]
                    value_states = past_key_value.value_cache[self.layer_idx]
                else:
                    past_key_value.key_cache.append(key_states)
                    past_key_value.value_cache.append(value_states)
            else:
                # Specific to RoPE models with partial rotation
                cache_kwargs = {
                    "sin": sin,
                    "cos": cos,
                    "partial_rotation_size": self.rotary_ndims,
                    "cache_position": cache_position,
                }
                key_states, value_states = past_key_value.update(
                    key_states, value_states, self.layer_idx, cache_kwargs
                )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



