optimum/habana/transformers/models/chatglm/modeling_chatglm.py [540:648]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    ):
        # hidden_states: [sq, b, h]
        q_len, bsz, hiddenSize = hidden_states.size()

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

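        # Split the fused QKV projection: with multi-query attention the query keeps all attention heads
        # while key/value use the smaller number of KV groups; otherwise split evenly into per-head Q, K, V.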
        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

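        # apply relative positional encoding (rotary embedding)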
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

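        # Prepend prefix-tuning key/value states along the sequence dimension, cast to the working dtype
        # (via cast_to_fp8_v2 when running in fp8).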
        if prefix_encoder is not None:
            prefix_encoder_key, prefix_encoder_value = prefix_encoder
            if mixed_x_layer.dtype == torch.float8_e4m3fn:
                from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2

                prefix_encoder_key = cast_to_fp8_v2(prefix_encoder_key, None, False, False, mixed_x_layer.dtype)[0]
                prefix_encoder_value = cast_to_fp8_v2(prefix_encoder_value, None, False, False, mixed_x_layer.dtype)[0]
            else:
                prefix_encoder_key = prefix_encoder_key.to(mixed_x_layer.dtype)
                prefix_encoder_value = prefix_encoder_value.to(mixed_x_layer.dtype)

            key_layer = torch.cat((prefix_encoder_key, key_layer), dim=0)
            value_layer = torch.cat((prefix_encoder_value, value_layer), dim=0)

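        # [sq, b, heads, hn] --> [b, heads, sq, hn]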
        query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
        key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
        value_layer = value_layer.permute(1, 2, 0, 3).contiguous()

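        # KV caching: with reuse_cache the pre-allocated k/v caches are updated in place and only their
        # shapes are returned; otherwise a zero-filled cache is created on the first step and updated at token_idx.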
        if use_cache:
            # reuse cached k, v for self-attention
            if reuse_cache:
                key_layer = self.k_cache(key_layer, 2, token_idx)
                value_layer = self.v_cache(value_layer, 2, token_idx)
                past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
            else:
                if past_key_value is None:
                    past_key = torch.zeros(
                        key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
                    )
                    past_value = torch.zeros(
                        key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
                    )
                    past_key_value = [past_key, past_value]
                key_layer = self.k_cache.update(past_key_value[0], key_layer, 2, token_idx, self.inp_seq_len)
                value_layer = self.v_cache.update(past_key_value[1], value_layer, 2, token_idx, self.inp_seq_len)
                if token_idx is None:
                    past_key_value = (key_layer, value_layer)

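            # During single-token decode, trim the cached key/values (and the attention mask) to the valid cache length.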
            if cache_idx is not None and q_len == 1:
                key_layer = key_layer[:, :, :cache_idx, :]
                value_layer = value_layer[:, :, :cache_idx, :]
                if attention_mask is not None:
                    attention_mask = attention_mask[:, :, :, :cache_idx]
        else:
            past_key_value = None

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            cache_position,
            attn_softmax_bf16,
            use_flash_attention,
            flash_attention_recompute,
            flash_attention_causal_mask,
            flash_attention_fast_softmax,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



optimum/habana/transformers/models/glm4v/modeling_chatglm.py [645:754]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    ):
        # hidden_states: [sq, b, h]
        q_len, bsz, hiddenSize = hidden_states.size()

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

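        # Split the fused QKV projection: with multi-query attention the query keeps all attention heads
        # while key/value use the smaller number of KV groups; otherwise split evenly into per-head Q, K, V.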
        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

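        # Prepend prefix-tuning key/value states along the sequence dimension, cast to the working dtype
        # (via cast_to_fp8_v2 when running in fp8).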
        if prefix_encoder is not None:
            prefix_encoder_key, prefix_encoder_value = prefix_encoder
            if mixed_x_layer.dtype == torch.float8_e4m3fn:
                from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2

                prefix_encoder_key = cast_to_fp8_v2(prefix_encoder_key, None, False, False, mixed_x_layer.dtype)[0]
                prefix_encoder_value = cast_to_fp8_v2(prefix_encoder_value, None, False, False, mixed_x_layer.dtype)[0]
            else:
                prefix_encoder_key = prefix_encoder_key.to(mixed_x_layer.dtype)
                prefix_encoder_value = prefix_encoder_value.to(mixed_x_layer.dtype)

            key_layer = torch.cat((prefix_encoder_key, key_layer), dim=0)
            value_layer = torch.cat((prefix_encoder_value, value_layer), dim=0)

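        # [sq, b, heads, hn] --> [b, heads, sq, hn]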
        query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
        key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
        value_layer = value_layer.permute(1, 2, 0, 3).contiguous()

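        # KV caching: with reuse_cache the pre-allocated k/v caches are updated in place and only their
        # shapes are returned; otherwise a zero-filled cache is created on the first step and updated at token_idx.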
        if use_cache:
            # reuse cached k, v for self-attention
            if reuse_cache:
                key_layer = self.k_cache(key_layer, 2, token_idx)
                value_layer = self.v_cache(value_layer, 2, token_idx)
                past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
            else:
                if past_key_value is None:
                    past_key = torch.zeros(
                        key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
                    )
                    past_value = torch.zeros(
                        key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
                    )
                    past_key_value = [past_key, past_value]
                key_layer = self.k_cache.update(past_key_value[0], key_layer, 2, token_idx, self.inp_seq_len)
                value_layer = self.v_cache.update(past_key_value[1], value_layer, 2, token_idx, self.inp_seq_len)
                if token_idx is None:
                    past_key_value = (key_layer, value_layer)

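            # During single-token decode, trim the cached key/values (and the attention mask) to the valid cache length.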
            if cache_idx is not None and q_len == 1:
                key_layer = key_layer[:, :, :cache_idx, :]
                value_layer = value_layer[:, :, :cache_idx, :]
                if attention_mask is not None:
                    attention_mask = attention_mask[:, :, :, :cache_idx]
        else:
            past_key_value = None

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            cache_position,
            attn_softmax_bf16,
            use_flash_attention,
            flash_attention_recompute,
            flash_attention_causal_mask,
            flash_attention_fast_softmax,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



