optimum/neuron/models/inference/backend/modules/attention/attention_base.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/attention/attention_base.py
import logging
import math
import warnings
from enum import Enum
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torch.distributed import ProcessGroup
from transformers import PretrainedConfig
from .utils import (
apply_rotary_pos_emb,
distributed_softmax,
manual_softmax,
move_heads_front,
repeat_kv,
)
# Try/except for compatibility with older compiler versions
try:
from neuronxcc.nki._private_kernels.attention import attention_isa_kernel # noqa: E402
except ImportError:
from neuronxcc.nki.kernels.attention import attention_isa_kernel # noqa: E402
import neuronx_distributed as nxd
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers import parallel_state, utils # noqa: E402
from neuronx_distributed.parallel_layers.layers import SPMDRank
from neuronx_distributed.parallel_layers.parallel_state import get_kv_shared_group
from neuronxcc.nki.language import nc
from torch_neuronx.xla_impl.ops import nki_jit # noqa: E402
from ...config import NxDNeuronConfig
from .gqa import GQA, GroupQueryAttention_O, GroupQueryAttention_QKV # noqa: E402
logger = logging.getLogger("Neuron")
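# JIT-wrap the NKI flash attention kernel so it can be called from the attention forward pass below.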
_flash_fwd_call = nki_jit()(attention_isa_kernel)
class FlashAttentionStrategy(Enum):
NONE = 0
UNSHARDED_KERNEL = 1
SHARDED_KERNEL = 2
class NeuronAttentionBase(nn.Module):
"""
This base attention class implements the core Neuron related adaptation including
1. replaces the q_proj, k_proj, v_proj with column parallel layer
2. replaces the o_proj with row parallel layer
3. update self.num_head to be self.num_head / tp_degree
4. update self.num_key_value_heads to be self.num_key_value_heads / tp_degree
5. update forward() method to adjust to changes from self.num_head
"""
def __init__(
self,
config: PretrainedConfig,
neuron_config: NxDNeuronConfig,
tensor_model_parallel_group: Optional[ProcessGroup] = None,
qkv_proj_bias: bool = False,
o_proj_bias: bool = False,
qk_scale: Optional[float] = None,
):
if not parallel_state.model_parallel_is_initialized():
raise ValueError(
"Neuron Attention has to be initialized in a distributed env. Please use neuronx_distributed"
" module to initialize a distributed env."
)
super().__init__()
if tensor_model_parallel_group is not None:
self.tensor_model_parallel_group = tensor_model_parallel_group
self.rank_util = SPMDRank(world_size=self.tensor_model_parallel_group.size())
else:
self.tensor_model_parallel_group = nxd.parallel_layers.parallel_state.get_tensor_model_parallel_group()
self.rank_util = SPMDRank(world_size=self.tensor_model_parallel_group.size())
self.is_causal = True
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = self.hidden_size // self.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.padding_side = neuron_config.padding_side
self.torch_dtype = neuron_config.torch_dtype
self.qk_layernorm = neuron_config.qk_layernorm
self.flash_decoding_enabled = neuron_config.flash_decoding_enabled
self.num_cores_per_group = neuron_config.num_cores_per_group
self.rpl_reduce_dtype = neuron_config.rpl_reduce_dtype
self.mlp_kernel_enabled = neuron_config.mlp_kernel_enabled
self.rms_norm_eps = config.rms_norm_eps
self.tp_degree = neuron_config.tp_degree
self.fused_qkv = neuron_config.fused_qkv
self.clip_qkv = None
self.qk_scale = qk_scale
self.o_proj_layer_name = "o_proj"
if (self.head_dim * self.num_attention_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_attention_heads})."
)
self.sequence_parallel_enabled = neuron_config.sequence_parallel_enabled
self.sequence_dimension = 1 if self.sequence_parallel_enabled else None
self.qkv_proj = GroupQueryAttention_QKV(
hidden_size=self.hidden_size,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
tp_degree=self.tp_degree,
dtype=self.torch_dtype,
bias=qkv_proj_bias,
gather_output=False,
fused_qkv=self.fused_qkv,
clip_qkv=self.clip_qkv,
sequence_parallel_enabled=self.sequence_parallel_enabled,
sequence_dimension=self.sequence_dimension,
tensor_model_parallel_group=self.tensor_model_parallel_group,
rms_norm_eps=self.rms_norm_eps,
qkv_kernel_enabled=neuron_config.qkv_kernel_enabled,
logical_nc_config=neuron_config.logical_nc_config,
)
self.o_proj = GroupQueryAttention_O(
hidden_size=self.hidden_size,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
tp_degree=self.tp_degree,
dtype=self.torch_dtype,
bias=o_proj_bias,
input_is_parallel=True,
layer_name=self.o_proj_layer_name,
sequence_parallel_enabled=self.sequence_parallel_enabled,
sequence_dimension=self.sequence_dimension,
tensor_model_parallel_group=self.tensor_model_parallel_group,
rpl_reduce_dtype=self.rpl_reduce_dtype,
)
self.num_heads = utils.divide(self.qkv_proj.get_num_attention_heads(), self.tp_degree)
self.num_key_value_heads = utils.divide(self.qkv_proj.get_num_key_value_heads(), self.tp_degree)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
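        # Illustrative split: with 32 attention heads, 8 KV heads and tp_degree=8, each rank gets
        # num_heads=4, num_key_value_heads=1 and num_key_value_groups=4.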
self.q_layernorm = nn.LayerNorm(self.head_dim) if self.qk_layernorm else None
self.k_layernorm = nn.LayerNorm(self.head_dim) if self.qk_layernorm else None
self.attn_kernel_enabled = neuron_config.attn_kernel_enabled
self.logical_nc_config = neuron_config.logical_nc_config
def scaled_qk(self, Q, K, attention_mask):
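        """Compute the scaled attention scores Q @ K^T * qk_scale (default 1/sqrt(head_dim)),
        with masked positions set to the dtype minimum so they vanish after the softmax."""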
qk_scale = self.qk_scale or (1.0 / math.sqrt(self.head_dim))
QK = torch.matmul(Q, K.transpose(2, 3)) * qk_scale
QK = torch.where(attention_mask, QK, torch.finfo(QK.dtype).min)
return QK
def prep_qkv_tensors(
self,
position_ids,
hidden_states,
past_key_value,
cos_cache=None,
sin_cache=None,
rmsnorm=None,
):
"""take care of the shape, layout, group query, custom position encoding, etc."""
Q, K, V = self.qkv_proj(hidden_states=hidden_states, rmsnorm=rmsnorm)
# Divide hidden_dim across heads for MHA
# Change layout: BSHD -> BHSD
bsz, q_len, _ = hidden_states.size()
if self.sequence_parallel_enabled:
q_len *= self.tensor_model_parallel_group.size()
Q = move_heads_front(Q, bsz, q_len, self.num_heads, self.head_dim, layernorm=self.q_layernorm)
K = move_heads_front(K, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=self.k_layernorm)
V = move_heads_front(V, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None)
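        # Resulting layouts: Q is (bsz, num_heads, q_len, head_dim);
        # K and V are (bsz, num_key_value_heads, q_len, head_dim).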
# Rotate Q and K
if self.rotary_emb is not None:
if cos_cache is None or sin_cache is None:
cos_cache, sin_cache = self.rotary_emb(V, position_ids)
Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache)
return Q, K, V, cos_cache, sin_cache
def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask) -> Tensor:
"""attention computation at prefilling (context encoding) phase"""
K_active = repeat_kv(K, self.num_key_value_groups)
V_active = repeat_kv(V, self.num_key_value_groups)
flash_attn_strategy = self.get_flash_attention_strategy(q_len)
logger.debug(f"Flash attention strategy: {flash_attn_strategy}")
if flash_attn_strategy != FlashAttentionStrategy.NONE:
logger.debug(f"ATTN kernel: logical_nc_config={self.logical_nc_config}")
            # If we are using left padding, then bsz needs to be 1 (otherwise we get a wrong result,
            # because flash attention does not use the attention_mask). In practice, we use right
            # padding, so this is unlikely to cause issues.
assert self.padding_side == "right" or bsz == 1
# original shape of q, k, v is BHSD, and expected output is also BHSD.
logger.debug(f"Using flash_fwd for Q.shape={Q.shape}")
# make sure to cast inputs to torch_dtype (this is needed because the downcast to bf16
# might happen after the kernel hlo creation step). Also convert shapes as expected by the kernel.
# original Q shape: batch, num_heads, seqlen, d_head
Q = (
Q.permute(0, 1, 3, 2) # after permute: batch, num_heads, d_head, seqlen
.reshape((bsz * self.num_heads, self.head_dim, q_len))
.to(self.torch_dtype)
)
Q = Q / math.sqrt(self.head_dim)
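            # Q is pre-scaled by 1/sqrt(head_dim) here, which is why the kernel calls below pass a
            # scale of 1.0.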
K_active = (
K_active.permute(0, 1, 3, 2).reshape((bsz * self.num_heads, self.head_dim, q_len)).to(self.torch_dtype)
)
V_active = V_active.reshape((bsz * self.num_heads, q_len, self.head_dim)).to(self.torch_dtype)
# shape: (B*H)DS
attn_output = torch.zeros(bsz * self.num_heads, self.head_dim, q_len, dtype=Q.dtype, device=Q.device)
logger.debug("Input parameter shapes")
logger.debug(f"Q input shape {Q.shape}")
logger.debug(f"K input shape {K_active.shape}")
logger.debug(f"V input shape {V_active.shape}")
logger.debug(f"Attn output shape {attn_output.shape}")
if flash_attn_strategy == FlashAttentionStrategy.SHARDED_KERNEL:
grid = (nc(self.logical_nc_config),)
_flash_fwd_call[grid](
Q,
K_active,
V_active,
1.0,
attn_output,
kernel_name="CausalAttentionMMSoftmaxMMWithoutSwap",
)
elif flash_attn_strategy == FlashAttentionStrategy.UNSHARDED_KERNEL:
_flash_fwd_call(
Q,
K_active,
V_active,
1.0,
attn_output,
kernel_name="CausalAttentionMMSoftmaxMMWithoutSwap",
)
else:
raise ValueError(f"Invalid flash attention strategy: {flash_attn_strategy}")
# shape: BHDS
attn_output = attn_output.reshape((bsz, self.num_heads, self.head_dim, q_len))
logger.debug(f"Attn output after reshape {attn_output.shape}")
else:
logger.debug("ATTN: native compiler")
logger.debug(f"Not using flash_fwd for Q.shape={Q.shape}")
active_scores = self.scaled_qk(Q, K_active, attention_mask)
active_scores = nn.functional.softmax(active_scores, dim=-1, dtype=torch.float32).to(Q.dtype)
attn_output = torch.matmul(active_scores, V_active)
return attn_output, flash_attn_strategy
def get_flash_attention_strategy(self, q_len) -> FlashAttentionStrategy:
"""
Gets the flash attention strategy.
For LNC1, use the unsharded kernel if sequence length is at least 4096 to get the best performance.
The unsharded kernel requires a sequence length of at least 512.
For LNC2, use the sharded kernel if sequence length is divisible by 1024. Otherwise, use no
kernel, because the unsharded kernel has worse performance than no kernel.
The sharded kernel requires a sequence length of at least 1024.
These constraints may change later.
TODO: Throw an exception instead of disabling flash attention if explicitly enabled but not eligible.
This must consider bucketing to avoid throwing an exception for smaller buckets.
"""
if self.qk_scale is not None:
# If a custom qk_scale is provided, flash attention is not supported.
return FlashAttentionStrategy.NONE
if int(self.logical_nc_config) > 1:
if q_len < 1024:
return FlashAttentionStrategy.NONE
if q_len % 1024 == 0:
return FlashAttentionStrategy.SHARDED_KERNEL
else:
warnings.warn("Flash attention disabled. LNC2 requires seq_len % 1024 for flash attn to be performant")
return FlashAttentionStrategy.NONE
# If seq_len is at least 4096, enable flash attn automatically to improve performance.
if q_len >= 4096:
return FlashAttentionStrategy.UNSHARDED_KERNEL
        # At lower seq lens, use the flash kernel only if it is explicitly enabled.
if self.attn_kernel_enabled and q_len >= 512:
return FlashAttentionStrategy.UNSHARDED_KERNEL
return FlashAttentionStrategy.NONE
def compute_for_flash_decoding(self, Q, K, V, past_key_value, attention_mask, active_mask) -> Tensor:
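        """Attention computation at the token generation phase when flash decoding is enabled.

        The prior KV seen here is the rank-local shard along the sequence dimension, so scores are
        normalized with distributed_softmax; the caller all-gathers Q beforehand and reduce-scatters
        the partial outputs across the KV-shared group.
        """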
# TODO: refactor/decompose this to reduce duplication with compute_for_token_gen
# active attention
n_repeat = Q.shape[1]
K_active = repeat_kv(K, n_repeat)
V_active = repeat_kv(V, n_repeat)
active_scores = (torch.matmul(Q, K_active.transpose(2, 3)) / math.sqrt(self.head_dim)).to(torch.float32)
active_scores = torch.where(active_mask, active_scores, torch.finfo(active_scores.dtype).min)
# prior attention
K_prior = repeat_kv(past_key_value[0], n_repeat)
V_prior = repeat_kv(past_key_value[1], n_repeat)
prior_scores = torch.matmul(Q, K_prior.transpose(2, 3)) / math.sqrt(self.head_dim)
prior_scores = torch.where(attention_mask, prior_scores, torch.finfo(prior_scores.dtype).min)
prior_scores = prior_scores.to(torch.float32)
# attention scores
softmax_prior, softmax_active = distributed_softmax(prior_scores, active_scores)
softmax_prior, softmax_active = softmax_prior.to(Q.dtype), softmax_active.to(Q.dtype)
attn_prior = torch.matmul(softmax_prior, V_prior)
attn_active = torch.matmul(softmax_active, V_active)
attn_output = attn_prior + attn_active
return attn_output
def compute_for_token_gen(self, Q, K, V, position_ids, past_key_value, attention_mask, active_mask) -> Tensor:
"""attention computation at token generation phase"""
is_speculation = position_ids.shape[-1] > 1
# Attention computation: softmax((Q.K/√dkv) + mask).V
# i. prior (cached) KV
K_prior = past_key_value[0]
V_prior = past_key_value[1]
K_prior = repeat_kv(K_prior, self.num_key_value_groups)
V_prior = repeat_kv(V_prior, self.num_key_value_groups)
prior_scores = torch.matmul(Q, K_prior.transpose(2, 3)) / math.sqrt(self.head_dim)
prior_scores = torch.where(attention_mask, prior_scores, torch.finfo(prior_scores.dtype).min)
prior_scores = prior_scores.to(torch.float32)
# ii. active (current/new) KV
K_active = repeat_kv(K, self.num_key_value_groups)
V_active = repeat_kv(V, self.num_key_value_groups)
active_scores = torch.matmul(Q, K_active.transpose(2, 3)) / math.sqrt(self.head_dim)
if is_speculation:
active_scores = torch.where(active_mask, active_scores, torch.finfo(active_scores.dtype).min)
active_scores = active_scores.to(torch.float32)
# iii. attention scores
softmax_prior, softmax_active = manual_softmax(prior_scores, active_scores, is_speculation)
softmax_prior, softmax_active = softmax_prior.to(Q.dtype), softmax_active.to(Q.dtype)
attn_prior = torch.matmul(softmax_prior, V_prior)
attn_active = torch.matmul(softmax_active, V_active)
attn_output = attn_prior + attn_active
return attn_output
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
active_mask: Optional[torch.LongTensor] = None,
cos_cache: Optional[torch.Tensor] = None,
sin_cache: Optional[torch.Tensor] = None,
rmsnorm=None,
) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]:
"""Implements each layer's forward pass for the attention block."""
bsz, q_len, _ = hidden_states.size()
if self.sequence_parallel_enabled:
q_len *= self.tensor_model_parallel_group.size()
Q, K, V, cos_cache, sin_cache = self.prep_qkv_tensors(
position_ids,
hidden_states,
past_key_value,
cos_cache=cos_cache,
sin_cache=sin_cache,
rmsnorm=rmsnorm,
)
flash_attn_strategy = FlashAttentionStrategy.NONE
if past_key_value is None:
attn_output, flash_attn_strategy = self.perform_prefill(Q, K, V, q_len, bsz, attention_mask)
if self.flash_decoding_enabled:
assert self.qkv_proj.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE, (
"Flash decoding lives in the context of GQA (grouped query attention) and traditional MHA "
"multi-head attention) won't work!"
)
rank_id = self.rank_util.get_rank()
rank_id_in_kv_group = torch.remainder(rank_id, self.num_cores_per_group).to(torch.int64)
# shard KV by seq len and pick the values based on rank
assert q_len == Q.shape[2], f"Q shape is {Q.shape}"
                # select positions (on the S dim) that belong to the current rank
offset = torch.arange(0, q_len, self.num_cores_per_group, dtype=torch.int64, device=Q.device)
selected_seq_pos = offset + rank_id_in_kv_group
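                # e.g. with q_len=8 and num_cores_per_group=4, the rank with rank_id_in_kv_group=1
                # keeps sequence positions [1, 5]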
K = torch.index_select(input=K, dim=2, index=selected_seq_pos)
V = torch.index_select(input=V, dim=2, index=selected_seq_pos)
else:
if self.flash_decoding_enabled:
                assert active_mask is not None, "Flash decoding requires that active_mask is not None!"
# gather Q from all cores in its KV group
groups = get_kv_shared_group(as_list=True)
Q = xm.all_gather(Q, dim=1, groups=groups, pin_layout=False)
attn_output = self.compute_for_flash_decoding(Q, K, V, past_key_value, attention_mask, active_mask)
attn_output = xm.reduce_scatter(
xm.REDUCE_SUM,
attn_output,
scale=1,
scatter_dim=1,
shard_count=len(groups[0]),
groups=groups,
pin_layout=False,
)
else:
attn_output = self.compute_for_token_gen(
Q, K, V, position_ids, past_key_value, attention_mask, active_mask
)
if flash_attn_strategy != FlashAttentionStrategy.NONE:
# transpose BHDS -> BSHD
# this layout avoids additional transposes between attention kernel and output projection
attn_output = attn_output.permute(0, 3, 1, 2)
else:
# transpose BHSD -> BSHD
attn_output = attn_output.transpose(1, 2).contiguous()
# merge multi head hidden
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
# Z = Z.Wo
attn_output = self.o_proj(attn_output)
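        # Return the (possibly sequence-sharded) new key/value states so the caller can update the KV cache.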
past_key_value: Tuple[Tensor, Tensor] = (K, V)
return attn_output, past_key_value, cos_cache, sin_cache