optimum/neuron/models/inference/llama/modeling_llama.py (396 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/llama/modeling_llama.py
"""PyTorch LLaMA model for NXD inference."""

import gc
import logging
import math
import warnings
from typing import Optional, Tuple, Type

import torch
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from neuronx_distributed.parallel_layers.mappings import (
    gather_from_sequence_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)
from neuronxcc.nki._private_kernels.mlp import (
    mlp_fused_add_isa_kernel,
    mlp_isa_kernel,
)
from neuronxcc.nki.language import nc
from torch import nn
from torch_neuronx.xla_impl.ops import nki_jit
from transformers.activations import ACT2FN
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding

from ..backend.config import NxDNeuronConfig  # noqa: E402
from ..backend.modules.attention.attention_base import NeuronAttentionBase
from ..backend.modules.attention.utils import (
    RotaryEmbedding,
    transpose_parallel_linear_layer,
)
from ..backend.modules.custom_calls import CustomRMSNorm
from ..backend.modules.decoder import NxDDecoderModel, NxDModelForCausalLM


logger = logging.getLogger("Neuron")


def convert_state_dict_to_fused_qkv(llama_state_dict, cfg: LlamaConfig):
    """
    Fuse the per-layer Q, K and V projection weights into a single Wqkv weight.

    For every layer index in ``range(cfg.num_hidden_layers)``, the
    ``layers.{l}.self_attn.{q,k,v}_proj.weight`` tensors are concatenated along
    dim 0 (``torch.cat`` default) in Q, K, V order. The result is stored under
    ``layers.{l}.self_attn.Wqkv.weight`` and the three original entries are deleted.

    Only ``.weight`` keys are handled; any q/k/v bias entries are left untouched.

    Args:
        llama_state_dict: state dict using the ``layers.{l}.self_attn.*`` key naming.
            It is modified in place.
        cfg: model config; only ``num_hidden_layers`` is read.

    Returns:
        The same (mutated) ``llama_state_dict`` object.

    Raises:
        KeyError: if a q/k/v projection weight is missing for some layer.
    """
    for l in range(cfg.num_hidden_layers):  # noqa: E741
        llama_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat(
            [
                llama_state_dict[f"layers.{l}.self_attn.q_proj.weight"],
                llama_state_dict[f"layers.{l}.self_attn.k_proj.weight"],
                llama_state_dict[f"layers.{l}.self_attn.v_proj.weight"],
            ],
        )
        del llama_state_dict[f"layers.{l}.self_attn.q_proj.weight"]
        del llama_state_dict[f"layers.{l}.self_attn.k_proj.weight"]
        del llama_state_dict[f"layers.{l}.self_attn.v_proj.weight"]

    # Encourage prompt release of the now-unreferenced q/k/v tensors.
    gc.collect()

    return llama_state_dict


class NeuronLlamaMLP(nn.Module):
    """
    Llama gated MLP (``down(act(gate(x)) * up(x))``) built from tensor-parallel layers.

    This class replaces the linear layers (gate_proj, up_proj and down_proj) with
    column and row parallel layers. When ``neuron_config.mlp_kernel_enabled`` is set,
    the forward pass runs a fused NKI MLP kernel instead of the native layers.
    """

    def __init__(self, config: LlamaConfig, neuron_config: NxDNeuronConfig):
        super().__init__()
        self.tp_degree = neuron_config.tp_degree
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.act_fn = ACT2FN[config.hidden_act]

        # Sequence parallelism is optional on the config; when enabled, the
        # sequence dimension (dim 1) is sharded across ranks outside the MLP.
        self.sequence_parallel_enabled = getattr(neuron_config, "sequence_parallel_enabled", False)
        self.sequence_dimension = 1 if self.sequence_parallel_enabled else None
        self.rms_norm_eps = config.rms_norm_eps
        self.mlp_kernel_enabled = neuron_config.mlp_kernel_enabled
        self.logical_nc_config = neuron_config.logical_nc_config
        mlp_bias = getattr(config, "mlp_bias", False)
        # gate/up projections: output sharded across TP ranks (no gather), and
        # built without SP because the all-gather is done once in the forward pass.
        self.gate_proj = ColumnParallelLinear(
            self.hidden_size,
            self.intermediate_size,
            bias=mlp_bias,
            gather_output=False,
            dtype=neuron_config.torch_dtype,
            pad=True,
            sequence_parallel_enabled=False,
            sequence_dimension=None,
        )
        self.up_proj = ColumnParallelLinear(
            self.hidden_size,
            self.intermediate_size,
            bias=mlp_bias,
            gather_output=False,
            dtype=neuron_config.torch_dtype,
            pad=True,
            sequence_parallel_enabled=False,
            sequence_dimension=None,
        )
        # down projection consumes the TP-sharded intermediate activations.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=mlp_bias,
            input_is_parallel=True,
            dtype=neuron_config.torch_dtype,
            pad=True,
            sequence_parallel_enabled=self.sequence_parallel_enabled,
            sequence_dimension=self.sequence_dimension,
            reduce_dtype=neuron_config.rpl_reduce_dtype,
        )

        if self.mlp_kernel_enabled:
            # Transpose the weights to the layout expected by kernels
            self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight)
            self.up_proj.weight = transpose_parallel_linear_layer(self.up_proj.weight)
            self.down_proj.weight = transpose_parallel_linear_layer(self.down_proj.weight)

    def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual):
        """
        Run the MLP through an NKI kernel, then reduce across ranks.

        Args:
            x: MLP input (attention output when ``residual`` is given).
                NOTE(review): indexed as (batch, seq, ...) below; confirm it is (batch, seq, hidden).
            fused_rmsnorm: whether the kernel should apply RMSNorm to its input.
            rmsnorm: norm module whose ``weight`` is passed to the kernel. It must not
                be None on this path, because ``rmsnorm.weight`` is read unconditionally.
            residual: if not None, the residual add is fused into the kernel
                (``mlp_fused_add_isa_kernel``). This is incompatible with sequence parallelism.

        Returns:
            ``(output_tensor, residual)``. ``output_tensor`` is all-reduced (no SP) or
            reduce-scattered (SP). ``residual`` is the kernel-computed residual-add
            result when a residual was fused, else None.
        """
        fused_residual = residual is not None
        logger.debug(
            f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_nc_config={self.logical_nc_config}"
        )

        # Choose which kernel to call
        if fused_residual:
            assert not self.sequence_parallel_enabled, (
                "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!"
            )
            # Using fused residual add
            _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel)
        else:
            _mlp_fwd_call = nki_jit()(mlp_isa_kernel)

        # With SP, each rank holds only a slice of the sequence; gather it
        # before running the (sequence-unsharded) kernel.
        if self.sequence_parallel_enabled:
            x = gather_from_sequence_parallel_region(x, self.sequence_dimension)

        # Build output tensor
        output_tensor_seqlen = x.shape[1]
        if fused_residual:
            # seqlen dim is doubled to store the residual add output
            output_tensor_seqlen *= 2

        output_tensor = torch.zeros(
            size=(
                x.shape[0],  # batch size
                output_tensor_seqlen,
                self.hidden_size,  # hidden size
            ),
            dtype=x.dtype,
            device=x.device,
        )

        # Grab weights
        # all weights of the layers are stored in (out, in) shape
        # NOTE(review): in kernel mode these weights were passed through
        # transpose_parallel_linear_layer in __init__; confirm the layout the kernel expects.
        # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden]
        ln_w = rmsnorm.weight.unsqueeze(0)
        gate_w = self.gate_proj.weight.data
        up_w = self.up_proj.weight.data
        down_w = self.down_proj.weight.data

        # Kernel launch grid: one entry sized by the logical NeuronCore config.
        grid = (nc(self.logical_nc_config),)

        if fused_residual:
            _mlp_fwd_call[grid](
                x,  # attn_output
                residual,  # hidden
                ln_w,  # ln_w
                gate_w,  # gate_w
                up_w,  # up_w
                down_w,  # down_w
                output_tensor,  # out
                fused_rmsnorm=fused_rmsnorm,
                eps=self.rms_norm_eps,
                kernel_name="MLP",
                store_add=True,
            )
            # Split the doubled buffer: first half is the MLP output, second
            # half is the residual-add result written by the kernel.
            original_seqlen = x.shape[1]
            residual = output_tensor[:, original_seqlen:, :]
            output_tensor = output_tensor[:, :original_seqlen, :]
        else:
            _mlp_fwd_call[grid](
                x,  # hidden
                # should be fine to pass gamma as a dummy even if not using fused rmsnorm
                ln_w,
                gate_w,
                up_w,
                down_w,
                output_tensor,  # out
                # Run RMSNorm inside the kernel if NOT using SP rmsnorm
                fused_rmsnorm=fused_rmsnorm,
                eps=self.rms_norm_eps,
                kernel_name="MLP",
            )
            residual = None

        # All-reduce or reduce-scatter, depending on whether SP is enabled
        if self.sequence_parallel_enabled:
            output_tensor = reduce_scatter_to_sequence_parallel_region(output_tensor, self.sequence_dimension)
        else:
            output_tensor = reduce_from_tensor_model_parallel_region(output_tensor)

        logger.debug(f"MLP output shape {output_tensor.shape}")
        return (output_tensor, residual)

    def _native_mlp(self, x, rmsnorm):
        """
        Compute ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` with the parallel layers.

        ``rmsnorm`` is accepted for signature parity with the kernel path but is not
        used here; the caller is expected to have normalized ``x`` already.
        NOTE(review): the cross-rank reduction is presumably performed inside
        ``down_proj`` (a RowParallelLinear built with the SP flag). Confirm this.
        """
        logger.debug("MLP: native compiler")
        # all-gather is done here instead of CPL layers to
        # avoid 2 all-gathers from up and gate projections
        if self.sequence_parallel_enabled:
            x = gather_from_sequence_parallel_region(x, self.sequence_dimension)

        gate_proj_output = self.gate_proj(x)
        up_proj_output = self.up_proj(x)
        down_proj_input = self.act_fn(gate_proj_output) * up_proj_output
        output = self.down_proj(down_proj_input)
        logger.debug(f"MLP output shape {output.shape}")
        return output

    def forward(self, x, rmsnorm=None, residual=None):
        """
        Apply the MLP, using the NKI kernel when ``mlp_kernel_enabled`` is set.

        If a residual is passed in and the kernel is enabled, its add is fused into
        the MLP kernel. On the native (non-kernel) path, ``residual`` is ignored.

        Returns:
            A tuple ``(output, residual)``. ``residual`` is the output of the fused
            residual add, or None when no add was fused (always None on the native path).
        """
        if self.mlp_kernel_enabled:
            # RMSNorm runs inside the kernel unless SP handles it outside.
            fused_rmsnorm = not self.sequence_parallel_enabled
            # MLP kernel
            return self._kernel_enabled_mlp(x, fused_rmsnorm, rmsnorm, residual)
        else:
            # No kernel
            return (self._native_mlp(x, rmsnorm), None)


class NeuronLlamaAttention(NeuronAttentionBase):
    """
    Llama attention for Neuron.

    The only difference from NeuronAttentionBase is the choice of the Llama rotary
    embedding, made from ``config.rope_scaling``:
      * no ``rope_scaling``: plain ``RotaryEmbedding``;
      * ``rope_type == "llama3"``: ``Llama3RotaryEmbedding`` (fp32 inv_freq);
      * any other type: Hugging Face ``LlamaRotaryEmbedding``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        neuron_config: NxDNeuronConfig,
        qkv_proj_bias: Optional[bool] = False,
        o_proj_bias: Optional[bool] = False,
        qk_scale: Optional[float] = None,
    ):
        # NOTE(review): config.attention_bias is not consulted here; the bias flags
        # come only from these arguments (False by default).
        super().__init__(
            config, neuron_config, qkv_proj_bias=qkv_proj_bias, o_proj_bias=o_proj_bias, qk_scale=qk_scale
        )
        # NOTE(review): head_dim is derived from hidden_size / num_attention_heads and
        # ignores any explicit config.head_dim; confirm NeuronAttentionBase agrees.
        head_dim = config.hidden_size // config.num_attention_heads
        if not hasattr(config, "rope_scaling") or config.rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(
                head_dim,
                max_position_embeddings=config.max_position_embeddings,
                base=config.rope_theta,
            )
        else:
            # Newer configs use "rope_type"; older ones use "type".
            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", None))
            if rope_type == "llama3":
                self.rotary_emb = Llama3RotaryEmbedding(
                    dim=head_dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=config.rope_theta,
                    factor=config.rope_scaling["factor"],
                    low_freq_factor=config.rope_scaling["low_freq_factor"],
                    high_freq_factor=config.rope_scaling["high_freq_factor"],
                    original_max_position_embeddings=config.rope_scaling["original_max_position_embeddings"],
                )
            else:
                # LlamaRotaryEmbedding automatically chooses the correct scaling type from config.
                # Warning: The HF implementation may have precision issues when run on Neuron.
                # We include it here for compatibility with other scaling types.
                self.rotary_emb = LlamaRotaryEmbedding(config)


# TODO: Modularize RotaryEmbedding. See how HF transformers does it in 4.43.
class Llama3RotaryEmbedding(nn.Module):
    """
    Rotary embedding with Llama 3 ("llama3") frequency rescaling.

    Adapted from the transformers 4.43 Llama implementation:
    * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/llama/modeling_llama.py#L78
    * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/modeling_rope_utils.py#L345

    This implementation ensures inv_freq is calculated and stored in fp32.
    inv_freq is computed lazily on the first forward call and cached in a
    non-persistent buffer, so it is not part of the state dict.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=131072,
        base=500000.0,
        factor=8.0,
        low_freq_factor=1.0,
        high_freq_factor=4.0,
        original_max_position_embeddings=8192,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.factor = factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.old_context_len = original_max_position_embeddings
        # Placeholder; filled in on the first forward() call.
        self.register_buffer("inv_freq", None, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Return ``(cos, sin)`` tables for the given positions.

        Args:
            x: only its device and dtype are used.
            position_ids: 2-D position tensor, indexed as (batch, seq_len).

        Returns:
            ``cos`` and ``sin`` of shape (batch, seq_len, dim), cast to ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.inv_freq is None:
            inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )

            # Per-frequency rescaling: short wavelengths are kept, long wavelengths
            # are divided by `factor`, and the band in between is linearly blended.
            low_freq_wavelen = self.old_context_len / self.low_freq_factor
            high_freq_wavelen = self.old_context_len / self.high_freq_factor
            new_freqs = []
            for freq in inv_freq:
                wavelen = 2 * math.pi / freq
                if wavelen < high_freq_wavelen:
                    new_freqs.append(freq)
                elif wavelen > low_freq_wavelen:
                    new_freqs.append(freq / self.factor)
                else:
                    assert low_freq_wavelen != high_freq_wavelen
                    smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
                        self.high_freq_factor - self.low_freq_factor
                    )
                    new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq)
            self.inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)

        # [bs, dim/2, 1] @ [bs, 1, seq_len] -> [bs, dim/2, seq_len]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Keep the angle computation in fp32 regardless of any active autocast.
        with torch.autocast(device_type=x.device.type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class NeuronLlamaDecoderLayer(nn.Module):
    """
    Llama decoder layer that uses the NxD attention and NxD MLP.

    Each RMSNorm is applied here only when the matching kernel does not
    fuse it (kernel disabled, or sequence parallelism enabled).
    """

    def __init__(self, config: LlamaConfig, neuron_config: NxDNeuronConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = NeuronLlamaAttention(config, neuron_config)
        self.mlp = NeuronLlamaMLP(config, neuron_config)
        logger.debug(
            f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}"
        )
        self.input_layernorm = CustomRMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = CustomRMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.qkv_kernel_enabled = neuron_config.qkv_kernel_enabled
        self.mlp_kernel_enabled = neuron_config.mlp_kernel_enabled
        self.mlp_kernel_fuse_residual_add = neuron_config.mlp_kernel_fuse_residual_add
        self.sequence_parallel_enabled = neuron_config.sequence_parallel_enabled
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run attention and MLP, each with a residual connection.

        Extra ``kwargs`` are forwarded to the attention module.

        Returns:
            A 4-tuple ``(hidden_states, present_key_value, cos_cache, sin_cache)``.
            The return annotation lists only two elements, but four are returned.
        """
        residual = hidden_states

        # RMSNorm (fused with QKV kernel when SP is disabled)
        # NOTE(review): in that case the norm is presumably applied inside
        # self_attn via the `rmsnorm` argument below; confirm in NeuronAttentionBase.
        if (not self.qkv_kernel_enabled or self.sequence_parallel_enabled) and self.input_layernorm:
            hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            rmsnorm=self.input_layernorm,
            **kwargs,
        )

        if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add:
            assert not self.sequence_parallel_enabled, (
                "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled"
            )
            # First residual add handled in the MLP kernel; it returns the
            # post-add tensor as the new residual.
            hidden_states, residual = self.mlp(
                hidden_states,
                rmsnorm=self.post_attention_layernorm,
                residual=residual,
            )
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            # RMSNorm (fused with MLP kernel when SP is disabled)
            if not self.mlp_kernel_enabled or self.sequence_parallel_enabled:
                hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states, _ = self.mlp(
                hidden_states,
                rmsnorm=self.post_attention_layernorm,
            )

        # Second residual add (around the MLP).
        hidden_states = residual + hidden_states

        outputs = (hidden_states, present_key_value, cos_cache, sin_cache)
        return outputs


class NxDLlamaModel(NxDDecoderModel):
    """
    The Neuron version of the LlamaModel.

    Builds the parallel embedding, the LM head, ``num_hidden_layers`` decoder
    layers and the final RMSNorm.
    """

    def __init__(self, config: LlamaConfig, neuron_config: NxDNeuronConfig):
        super().__init__(config, neuron_config)
        # With vocab_parallel, the embedding is sharded over the vocabulary and uses
        # the SPMD rank; otherwise it is sharded across the embedding dimension.
        self.embed_tokens = ParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id,
            dtype=neuron_config.torch_dtype,
            shard_across_embedding=not neuron_config.vocab_parallel,
            sequence_parallel_enabled=False,
            pad=True,
            use_spmd_rank=neuron_config.vocab_parallel,
        )

        # Logits are gathered across ranks only when sampling is not done on device.
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            gather_output=not neuron_config.on_device_sampling,
            bias=False,
            pad=True,
        )

        self.layers = nn.ModuleList(
            [NeuronLlamaDecoderLayer(config, neuron_config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = CustomRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class LlamaNxDModelForCausalLM(NxDModelForCausalLM):
    """
    Causal-LM wrapper that builds traceable Neuron blocks around ``NxDLlamaModel``.

    Provides the Llama-specific Hugging Face -> Neuron state-dict conversion,
    handling of tied embeddings, and a default ``NxDNeuronConfig`` factory.
    """

    _model_cls = NxDLlamaModel

    @staticmethod
    def convert_hf_to_neuron_state_dict(state_dict: dict, config: LlamaConfig, neuron_config: NxDNeuronConfig) -> dict:
        """
        Adapt a Hugging Face Llama state dict for the Neuron model.

        This function should be overridden in child classes as needed. It:
          * fuses q/k/v weights into Wqkv when ``neuron_config.fused_qkv`` is set;
          * adds ``embed_tokens.rank_util.rank`` when ``vocab_parallel`` is set;
          * adds int32 rank tensors ``arange(0, tp_degree)`` for every attention
            layer and for the base model (``rank_util.rank``).

        Returns:
            The updated state dict. Entries are added or removed in place.
        """
        if neuron_config.fused_qkv:
            state_dict = convert_state_dict_to_fused_qkv(state_dict, config)

        if neuron_config.vocab_parallel:
            # TODO: this hack can be removed after replication_id is ready to use
            state_dict["embed_tokens.rank_util.rank"] = torch.arange(0, neuron_config.local_ranks_size)

        # to facilitate rank usage in attention
        num_layers = config.num_hidden_layers
        tp_degree = neuron_config.tp_degree
        for i in range(num_layers):
            state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)
        # to facilitate rank usage in base model
        state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)
        return state_dict

    @staticmethod
    def update_state_dict_for_tied_weights(state_dict):
        """Set ``lm_head.weight`` to a copy of ``embed_tokens.weight`` (in place; returns None)."""
        state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone()

    @classmethod
    def get_neuron_config_cls(cls) -> Type[NxDNeuronConfig]:
        """Return the neuron config class used by this model."""
        return NxDNeuronConfig

    @classmethod
    def _get_neuron_config(
        cls,
        checkpoint_id: str,
        checkpoint_revision: str,
        batch_size: int,
        sequence_length: int,
        tensor_parallel_size: int,
        auto_cast_type: str,
    ):
        """
        Build the default ``NxDNeuronConfig`` for a Llama export.

        Continuous batching is enabled when ``batch_size > 1``. On-device sampling
        is enabled unless continuous batching is combined with a tensor parallel
        size of 2, in which case a warning is emitted and it is disabled.
        Fused QKV is always enabled.
        """
        continuous_batching = (batch_size > 1) if batch_size else False
        on_device_sampling = True
        if continuous_batching and tensor_parallel_size == 2:
            # Neuron SDK 2.22 bug: the model will crash when continuous_batching is enabled
            # if the tensor parallel size is 2 and on_device_sampling is enabled.
            warnings.warn(
                "Activating continuous batching but disabling on-device sampling because of a neuron runtime bug when tensor parallel size is 2."
            )
            on_device_sampling = False
        return NxDNeuronConfig(
            checkpoint_id=checkpoint_id,
            checkpoint_revision=checkpoint_revision,
            batch_size=batch_size,
            sequence_length=sequence_length,
            tp_degree=tensor_parallel_size,
            torch_dtype=auto_cast_type,
            on_device_sampling=on_device_sampling,
            fused_qkv=True,
            continuous_batching=continuous_batching,
        )