# optimum/neuron/models/inference/backend/modules/attention/utils.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/attention/utils.py
import math
from typing import Any, Dict, Tuple
import torch
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_kv_shared_group,
get_tensor_model_parallel_group,
)
from neuronx_distributed.parallel_layers.utils import get_padding_length
from torch import Tensor, nn
torch.manual_seed(0)
weight_cache = {}
def _get_weight_from_state_dict(prefix: str, state_dict: Dict[str, Any]) -> torch.Tensor:
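    """Return the transposed weight for ``prefix`` from ``state_dict``, caching the result."""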
if prefix in weight_cache:
return weight_cache[prefix]
if (prefix + "weight") in state_dict:
transposed_weight = state_dict[prefix + "weight"].t()
weight_cache[prefix] = transposed_weight
return transposed_weight
else:
raise RuntimeError(f"Cannot find {(prefix + 'weight')} in the state_dict")
def _set_weight_to_state_dict(prefix: str, tensor: torch.Tensor, state_dict: Dict[str, Any]) -> None:
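    """Store ``tensor`` back into ``state_dict`` under ``prefix + "weight"``, transposed."""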
if (prefix + "weight") in state_dict:
state_dict[prefix + "weight"] = tensor.t()
else:
raise RuntimeError(f"Cannot find {(prefix + 'weight')} in the state_dict")
def transpose_parallel_linear_layer(parallel_layer):
"""
This function clones and transposes a ColumnParallelLinear or RowParallelLinear
The attributes are also cloned and partition_dim is updated
"""
orig_attrs = vars(parallel_layer)
new_layer = torch.nn.Parameter(parallel_layer.clone().T, requires_grad=False)
new_layer.__dict__.update(orig_attrs)
# flip the partition_dim from 0->1 or 1->0
setattr(new_layer, "partition_dim", 1 - getattr(new_layer, "partition_dim"))
setattr(new_layer, "get_tensor_from_state_dict", _get_weight_from_state_dict)
setattr(new_layer, "set_tensor_to_state_dict", _set_weight_to_state_dict)
return new_layer
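
# Illustrative sketch (hypothetical tensor, not part of the original module): a sharded weight
# that already carries a partition_dim attribute, as neuronx_distributed parallel-layer weights do.
#   w = torch.nn.Parameter(torch.rand(1024, 256), requires_grad=False)
#   w.partition_dim = 0
#   w_t = transpose_parallel_linear_layer(w)
#   assert w_t.shape == (256, 1024) and w_t.partition_dim == 1
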
def pad_to_128_multiple(x, dim):
    # Strided padding for the unsharded weight, so that after sharding
    # each rank ends up with dense padding at the end.
    # E.g. orig shape = [16384, 53248] with dim = 1 and TP_degree = 128:
    # we reshape to [16384, 128, 416],
    # then pad to [16384, 128, 512],
    # then collapse back to the original dim: [16384, 65536].
TP_DEGREE = get_tensor_model_parallel_group().size()
orig_shape = x.shape
new_shape = list(x.shape)
new_shape[dim] = orig_shape[dim] // TP_DEGREE
new_shape.insert(dim, TP_DEGREE)
x = x.reshape(new_shape)
dim += 1
padding_length = get_padding_length(x.shape[dim], 128)
dimlist = [0] * (len(x.shape) * 2)
dimlist[dim * 2] = padding_length
padded = torch.nn.functional.pad(x, tuple(dimlist[::-1]))
new_padded_shape = list(orig_shape)
new_padded_shape[dim - 1] = -1
padded = padded.reshape(new_padded_shape)
return padded
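
# Illustrative sketch (hypothetical shapes; requires an initialized tensor-parallel group, here
# assumed to have size 2): a [4, 300] weight with dim=1 is reshaped to [4, 2, 150], padded to
# [4, 2, 256] so each rank's shard becomes a multiple of 128, then collapsed back to [4, 512].
#   padded = pad_to_128_multiple(torch.rand(4, 300), dim=1)
#   assert padded.shape == (4, 512)
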
def move_heads_front(tensor: Tensor, bsz: int, seq_len: int, num_head: int, head_dim: int, layernorm=None) -> Tensor:
"""Reshape input tensor: BSHD -> BHSD, and apply layer normalization if layernorm is specified"""
tensor = tensor.view(bsz, seq_len, num_head, head_dim)
if layernorm:
tensor = layernorm(tensor)
return tensor.transpose(1, 2).contiguous()
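
# Illustrative usage (hypothetical shapes):
#   hidden = torch.rand(2, 16, 8 * 64)  # [bsz, seq_len, num_head * head_dim]
#   bhsd = move_heads_front(hidden, bsz=2, seq_len=16, num_head=8, head_dim=64)
#   assert bhsd.shape == (2, 8, 16, 64)  # BSHD -> BHSD
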
def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
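
# Illustrative usage (hypothetical shapes): with 2 KV heads and n_rep=4, the KV tensor is
# expanded to 8 heads so it lines up with 8 query heads.
#   kv = torch.rand(1, 2, 16, 64)  # [batch, num_key_value_heads, seqlen, head_dim]
#   assert repeat_kv(kv, n_rep=4).shape == (1, 8, 16, 64)
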
def _rotate_half(x) -> Tensor:
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
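
# Worked example: _rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])) -> tensor([-3., -4., 1., 2.])
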
def apply_scaling(freqs: torch.Tensor):
    # Values obtained from grid search, specifically for the Llama 3.2 multimodal (MM) PyTorch implementation
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
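
# Illustrative usage (hypothetical values, assuming the Llama 3 base theta of 500000.0):
#   base_freqs = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))
#   scaled = apply_scaling(base_freqs)  # same shape; low-frequency components divided by 8
#   assert scaled.shape == base_freqs.shape
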
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
return freqs
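
# Illustrative usage (hypothetical sizes): angles for a 64-dim head over 32 positions; the
# result is consumed by apply_rotary_polar_compatible below via its cos()/sin().
#   freqs = precompute_freqs_cis(dim=64, end=32)
#   assert freqs.shape == (32, 32)  # [end, dim // 2]
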
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1) -> Tuple[Tensor, Tensor]:
"""Applies Rotary Position Embedding to the query and key tensors."""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (_rotate_half(q) * sin)
k_embed = (k * cos) + (_rotate_half(k) * sin)
return q_embed, k_embed
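
# Illustrative usage (hypothetical shapes): cos/sin are [bsz, seq_len, head_dim] and, with the
# default unsqueeze_dim=1, broadcast over the head dimension of BHSD query/key tensors.
#   q, k = torch.rand(1, 8, 16, 64), torch.rand(1, 2, 16, 64)
#   emb = torch.cat((precompute_freqs_cis(dim=64, end=16),) * 2, dim=-1)[None]  # [1, 16, 64]
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
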
def apply_rotary_polar_compatible(query, key, freqs_cis):
freqs_cis_real = freqs_cis.cos().unsqueeze(2)
freqs_cis_imag = freqs_cis.sin().unsqueeze(2)
def rotate(input):
real = input[..., ::2]
imag = input[..., 1::2]
# For complex multiplication
# (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
# ac - bd
rot_real = (real * freqs_cis_real) - (imag * freqs_cis_imag)
# ad + bc
rot_imag = (real * freqs_cis_imag) + (freqs_cis_real * imag)
return torch.cat([rot_real.unsqueeze(-1), rot_imag.unsqueeze(-1)], dim=-1).reshape(input.shape)
query_rot = rotate(query)
key_rot = rotate(key)
return query_rot.type_as(query), key_rot.type_as(key)
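
# Illustrative usage (hypothetical shapes): here query/key are laid out as
# [bsz, seq_len, n_heads, head_dim] and freqs_cis ([bsz, seq_len, head_dim // 2]) is broadcast
# over the head dimension via the unsqueeze(2) above.
#   q, k = torch.rand(1, 16, 8, 64), torch.rand(1, 16, 2, 64)
#   freqs_cis = precompute_freqs_cis(dim=64, end=16)[None]  # [1, 16, 32]
#   q_rot, k_rot = apply_rotary_polar_compatible(q, k, freqs_cis)
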
def manual_softmax(prior_scores, active_scores, is_speculation) -> Tuple[Tensor, Tensor]:
"""
simple softmax computation: denominator is the sum of exp over all vocab and only need compute numerator (exp)
"""
max_score = torch.max(prior_scores, dim=-1, keepdim=True)[0]
max_active_score = torch.max(active_scores, dim=-1, keepdim=True)[0]
max_score = (
torch.maximum(max_score, max_active_score) if is_speculation else torch.maximum(max_score, active_scores)
)
exp_prior = torch.exp(prior_scores - max_score)
exp_active = torch.exp(active_scores - max_score)
denominator = exp_prior.sum(dim=-1, keepdim=True) + exp_active.sum(dim=-1, keepdim=True)
softmax_prior = exp_prior / denominator
softmax_active = exp_active / denominator
return softmax_prior, softmax_active
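
# Illustrative check (hypothetical shapes): concatenating the two blocks and applying a regular
# softmax matches the pair returned here.
#   prior, active = torch.rand(1, 8, 1, 32), torch.rand(1, 8, 1, 1)
#   p, a = manual_softmax(prior, active, is_speculation=False)
#   ref = torch.softmax(torch.cat([prior, active], dim=-1), dim=-1)
#   assert torch.allclose(torch.cat([p, a], dim=-1), ref)
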
def distributed_softmax(prior_scores, active_scores) -> Tuple[Tensor, Tensor]:
"""
compute partial softmax and then gather and correct final softmax.
"""
# find local max
max_score = torch.max(prior_scores, dim=-1, keepdim=True)[0]
max_active_score = torch.max(active_scores, dim=-1, keepdim=True)[0]
local_max_score = torch.maximum(max_score, max_active_score)
exp_prior = torch.exp(prior_scores - local_max_score)
exp_active = torch.exp(active_scores - local_max_score)
denominator = exp_prior.sum(dim=-1, keepdim=True) + exp_active.sum(dim=-1, keepdim=True)
# collect for global max and exp sum (denominator)
groups = get_kv_shared_group(as_list=True)
gather_payload = torch.cat((local_max_score, denominator), dim=0)
gathered_res = xm.all_gather(gather_payload, dim=-1, groups=groups, pin_layout=False)
gathered_max, gathered_denom = torch.chunk(gathered_res, 2, dim=0)
global_max = torch.max(gathered_max, dim=-1, keepdim=True)[0]
# softmax correction
scaling_factor = torch.exp(gathered_max - global_max.expand(gathered_max.shape))
corrected_denominator = torch.multiply(scaling_factor, gathered_denom)
corrected_denominator = torch.sum(corrected_denominator, dim=-1, keepdim=True)
corrected_exp_prior = torch.exp(prior_scores - global_max)
corrected_exp_active = torch.exp(active_scores - global_max)
softmax_prior = corrected_exp_prior / corrected_denominator
softmax_active = corrected_exp_active / corrected_denominator
return softmax_prior, softmax_active
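
# Note: unlike manual_softmax above, this variant assumes the scores are sharded across the
# KV-shared group, so it only produces correct results when traced and run on Neuron devices
# where xm.all_gather operates over that replica group.
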
class RotaryEmbedding(nn.Module):
"""
Adapted from Llama 4.0 impl https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models
/llama/modeling_llama.py#L96-L145
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
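
# Illustrative usage (hypothetical shapes):
#   rope = RotaryEmbedding(dim=64, max_position_embeddings=4096)
#   x = torch.rand(1, 8, 16, 64)  # [bs, num_attention_heads, seq_len, head_size]
#   cos, sin = rope(x, position_ids=torch.arange(16)[None, :])
#   assert cos.shape == (1, 16, 64)
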
# Utility functions to create attention mask
def create_block_diagonal_attn_mask(
query_lens: torch.Tensor,
key_lens: torch.Tensor,
max_query_len: torch.Tensor,
max_key_len: torch.Tensor,
):
"""
Return a block diagonal attention mask which can be used by chunked
prefill.
    This function is written so that it can be traced, which allows it to be
    used inside the NeuronDecoderModel class.
Example:
query_lens = [2,3,1,0]
key_lens = [4,5,4,0]
max_query_len = 8
max_key_len = 16
mask = [
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # At position 3 attend to 1st sequence
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # At position 4 attend to 1st sequence
[0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # At position 3 attend to 2nd sequence
[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], # At position 4 attend to 2nd sequence
[0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], # At position 5 attend to 2nd sequence
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0], # At position 4 attend to 3rd sequence
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # padding
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # padding
]
Args:
query_lens: a list of query lengths for each sequence
key_lens: a list of key lengths for each sequence
max_query_len: the max value of the sum of query lengths
max_key_len: the max value of the sum of key lengths
Return:
mask: the causal attention mask for chunked prefill
"""
batch_size = query_lens.shape[0]
row_idx = torch.arange(max_query_len, dtype=torch.int).reshape(-1, 1)
col_idx = torch.arange(max_key_len, dtype=torch.int).reshape(1, -1)
q_cumsum = torch.cumsum(query_lens, dim=0)
q_cumsum = torch.cat([torch.tensor(0).reshape(1), q_cumsum])
k_cumsum = torch.cumsum(key_lens, dim=0)
k_cumsum = torch.cat([torch.tensor(0).reshape(1), k_cumsum])
mask = torch.zeros(max_query_len, max_key_len, dtype=torch.bool)
for seq_id in range(batch_size):
ri = q_cumsum[seq_id] # row index
ci = k_cumsum[seq_id] # column index
nr = query_lens[seq_id] # number of rows
nc = key_lens[seq_id] # number of columns
offset = ci + nc - ri - nr
# upper right triangle is set to false
diagonal_mask = (row_idx - col_idx + offset) >= 0
left_mask = col_idx >= ci
top_mask = row_idx >= ri
bottom_mask = row_idx < ri + nr
mask_per_seq = diagonal_mask & left_mask & top_mask & bottom_mask
mask = mask | mask_per_seq
return mask
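
# Illustrative usage (values from the docstring example above; the scalar max lengths are passed
# as plain ints here for simplicity):
#   mask = create_block_diagonal_attn_mask(
#       query_lens=torch.tensor([2, 3, 1, 0]),
#       key_lens=torch.tensor([4, 5, 4, 0]),
#       max_query_len=8,
#       max_key_len=16,
#   )
#   assert mask.shape == (8, 16) and mask.dtype == torch.bool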