optimum/neuron/models/inference/mixtral/modeling_mixtral.py (184 lines of code) (raw):

# coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/mixtral/modeling_mixtral.py """PyTorch Mixtral model for NXD inference.""" import gc import warnings from typing import Optional, Tuple, Union import torch # Try except for the compatibility with older compiler version from neuronx_distributed.parallel_layers import parallel_state from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, ParallelEmbedding from torch import nn from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput from transformers.models.mixtral.modeling_mixtral import MixtralConfig from ..backend.config import NxDNeuronConfig from ..backend.modules.attention.attention_base import NeuronAttentionBase from ..backend.modules.attention.utils import RotaryEmbedding from ..backend.modules.custom_calls import CustomRMSNorm from ..backend.modules.decoder import NxDDecoderModel, NxDModelForCausalLM from ..backend.modules.moe import initialize_moe_module SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] def convert_mixtral_to_neuron_state_dict(neuron_state_dict, config, neuron_config): """ Helper function which returns the model weights from the mixtral model in a state dictionary compatible with the stucture of the neuron MoE model. """ assert neuron_config.glu_mlp is True, "Only GLU MLP is supported for Mixtral Top-K model" for l in range(config.num_hidden_layers): # noqa: E741 # Copy router weights neuron_state_dict[f"layers.{l}.mlp.router.linear_router.weight"] = ( neuron_state_dict[f"layers.{l}.block_sparse_moe.gate.weight"].detach().clone() ) del neuron_state_dict[f"layers.{l}.block_sparse_moe.gate.weight"] intermediate_size, hidden_size = neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.0.w1.weight"].shape device = neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.0.w1.weight"].device dtype = neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.0.w1.weight"].dtype # copy the MLP parameters gate_up_proj = torch.empty( config.num_local_experts, hidden_size, 2 * intermediate_size, dtype=dtype, device=device, ) for e in range(config.num_local_experts): # Copy gate_proj and up_proj after concatenation gate_proj_weights = ( neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w1.weight"].T.detach().clone() ) up_proj_weights = ( neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w3.weight"].T.detach().clone() ) gate_up_proj_slice = torch.narrow(gate_up_proj, 0, e, 1) gate_proj_slice = torch.narrow(gate_up_proj_slice, 2, 0, intermediate_size) gate_proj_slice.copy_(gate_proj_weights) up_proj_slice = torch.narrow(gate_up_proj_slice, 2, intermediate_size, intermediate_size) up_proj_slice.copy_(up_proj_weights) del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w1.weight"] del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w3.weight"] neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj down_proj = torch.empty( config.num_local_experts, intermediate_size, hidden_size, dtype=dtype, device=device, ) for e in range(config.num_local_experts): # Copy down_proj down_proj_weights = ( neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w2.weight"].T.detach().clone() ) down_proj_slice = torch.narrow(down_proj, 0, e, 1) down_proj_slice.copy_(down_proj_weights) del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w2.weight"] neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj gc.collect() return neuron_state_dict def get_rmsnorm_cls(neuron_config): # Initialize to the appropriate implementation of RMSNorm # If infer on NXD -> CustomRMSNorm # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) return CustomRMSNorm class NeuronMixtralAttention(NeuronAttentionBase): def __init__(self, config: MixtralConfig, neuron_config: NxDNeuronConfig): super().__init__(config, neuron_config) self.tp_degree = parallel_state.get_tensor_model_parallel_size() head_dim = config.hidden_size // config.num_attention_heads self.rotary_emb = RotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ) class NeuronMixtralDecoderLayer(nn.Module): """ Just replace the attention with the NXD version, and MLP with the NXD version """ def __init__(self, config: MixtralConfig, neuron_config: NxDNeuronConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = NeuronMixtralAttention(config, neuron_config) self.mlp = initialize_moe_module( neuron_config=neuron_config, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, ) self.input_layernorm = get_rmsnorm_cls(neuron_config)( config.hidden_size, eps=config.rms_norm_eps, ) self.post_attention_layernorm = get_rmsnorm_cls(neuron_config)( config.hidden_size, eps=config.rms_norm_eps, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. position_ids (`torch.FloatTensor`, *optional*): position ids of size `(batch_size, sequence_length)`. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, **kwargs, ) hidden_states = residual + hidden_states # MoE residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states)[0] hidden_states = residual + hidden_states outputs = (hidden_states, present_key_value, cos_cache, sin_cache) return outputs class NxDMixtralModel(NxDDecoderModel): """ NeuronMixtralModel extends the MixtralModel to be traceable. The forward function of this class is traced. """ def __init__(self, config: MixtralConfig, neuron_config: NxDNeuronConfig): super().__init__(config, neuron_config) self.embed_tokens = ParallelEmbedding( config.vocab_size, config.hidden_size, config.pad_token_id, dtype=neuron_config.torch_dtype, shard_across_embedding=True, ) self.layers = nn.ModuleList( [ NeuronMixtralDecoderLayer(config, neuron_config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] ) self.norm = get_rmsnorm_cls(neuron_config)(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = ColumnParallelLinear( config.hidden_size, config.vocab_size, gather_output=not neuron_config.on_device_sampling, bias=False, ) class MixtralNxDModelForCausalLM(NxDModelForCausalLM): """ This class can be used as MixtralForCausalLM """ _model_cls = NxDMixtralModel @classmethod def get_neuron_config_cls(cls): return NxDNeuronConfig @staticmethod def convert_hf_to_neuron_state_dict( state_dict: dict, config: MixtralConfig, neuron_config: NxDNeuronConfig ) -> dict: return convert_mixtral_to_neuron_state_dict(state_dict, config, neuron_config) @classmethod def get_compiler_args(cls, neuron_config: NxDNeuronConfig) -> str: compiler_args = "--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1" # Add flags for cc-overlap compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2'" compiler_args += " --auto-cast=none" # Enable vector-offset DGE compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets" return compiler_args @classmethod def _get_neuron_config( cls, checkpoint_id: str, checkpoint_revision: str, batch_size: int, sequence_length: int, tensor_parallel_size: int, auto_cast_type: str, ): return NxDNeuronConfig( checkpoint_id=checkpoint_id, checkpoint_revision=checkpoint_revision, batch_size=batch_size, sequence_length=sequence_length, tp_degree=tensor_parallel_size, torch_dtype=auto_cast_type, )