# optimum/neuron/models/inference/mixtral/modeling_mixtral.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/mixtral/modeling_mixtral.py
"""PyTorch Mixtral model for NXD inference."""
import gc
import warnings
from typing import Optional, Tuple, Union
import torch
from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, ParallelEmbedding
from torch import nn
from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput
from transformers.models.mixtral.modeling_mixtral import MixtralConfig
from ..backend.config import NxDNeuronConfig
from ..backend.modules.attention.attention_base import NeuronAttentionBase
from ..backend.modules.attention.utils import RotaryEmbedding
from ..backend.modules.custom_calls import CustomRMSNorm
from ..backend.modules.decoder import NxDDecoderModel, NxDModelForCausalLM
from ..backend.modules.moe import initialize_moe_module
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]


def convert_mixtral_to_neuron_state_dict(neuron_state_dict, config, neuron_config):
    """
    Helper function that remaps the weights of the Hugging Face Mixtral model into a state dictionary
    compatible with the structure of the Neuron MoE model.
    """
assert neuron_config.glu_mlp is True, "Only GLU MLP is supported for Mixtral Top-K model"
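    # Per-layer key remapping performed below (illustrative, for layer index l and expert index e):
    #   layers.{l}.block_sparse_moe.gate.weight         -> layers.{l}.mlp.router.linear_router.weight
    #   layers.{l}.block_sparse_moe.experts.{e}.w1 + w3 -> layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight (fused)
    #   layers.{l}.block_sparse_moe.experts.{e}.w2      -> layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight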
for l in range(config.num_hidden_layers): # noqa: E741
# Copy router weights
neuron_state_dict[f"layers.{l}.mlp.router.linear_router.weight"] = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.gate.weight"].detach().clone()
)
del neuron_state_dict[f"layers.{l}.block_sparse_moe.gate.weight"]
intermediate_size, hidden_size = neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.0.w1.weight"].shape
device = neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.0.w1.weight"].device
dtype = neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.0.w1.weight"].dtype
# copy the MLP parameters
gate_up_proj = torch.empty(
config.num_local_experts,
hidden_size,
2 * intermediate_size,
dtype=dtype,
device=device,
)
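        # gate_up_proj layout: [num_experts, hidden_size, 2 * intermediate_size]; the first half of the
        # last dim receives gate_proj (w1) and the second half up_proj (w3), both transposed from the HF layout.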
for e in range(config.num_local_experts):
# Copy gate_proj and up_proj after concatenation
gate_proj_weights = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w1.weight"].T.detach().clone()
)
up_proj_weights = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w3.weight"].T.detach().clone()
)
gate_up_proj_slice = torch.narrow(gate_up_proj, 0, e, 1)
gate_proj_slice = torch.narrow(gate_up_proj_slice, 2, 0, intermediate_size)
gate_proj_slice.copy_(gate_proj_weights)
up_proj_slice = torch.narrow(gate_up_proj_slice, 2, intermediate_size, intermediate_size)
up_proj_slice.copy_(up_proj_weights)
del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w1.weight"]
del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w3.weight"]
neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj
down_proj = torch.empty(
config.num_local_experts,
intermediate_size,
hidden_size,
dtype=dtype,
device=device,
)
for e in range(config.num_local_experts):
# Copy down_proj
down_proj_weights = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w2.weight"].T.detach().clone()
)
down_proj_slice = torch.narrow(down_proj, 0, e, 1)
down_proj_slice.copy_(down_proj_weights)
del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w2.weight"]
neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj
gc.collect()
    return neuron_state_dict


def get_rmsnorm_cls(neuron_config):
    # Return the appropriate RMSNorm implementation:
    # - inference on NXD uses CustomRMSNorm,
    # - inference on CPU would require the HF RMSNorm instead (CustomRMSNorm does not work on CPU).
    return CustomRMSNorm


class NeuronMixtralAttention(NeuronAttentionBase):
def __init__(self, config: MixtralConfig, neuron_config: NxDNeuronConfig):
super().__init__(config, neuron_config)
self.tp_degree = parallel_state.get_tensor_model_parallel_size()
head_dim = config.hidden_size // config.num_attention_heads
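        # Illustrative numbers for Mixtral-8x7B: hidden_size=4096 and num_attention_heads=32 give head_dim=128;
        # config.rope_theta (1e6 for Mixtral-8x7B) is the rotary base used below.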
self.rotary_emb = RotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
        )


class NeuronMixtralDecoderLayer(nn.Module):
    """
    Mixtral decoder layer with the attention and the MoE MLP replaced by their NXD implementations.
    """
def __init__(self, config: MixtralConfig, neuron_config: NxDNeuronConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = NeuronMixtralAttention(config, neuron_config)
self.mlp = initialize_moe_module(
neuron_config=neuron_config,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
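        # For Mixtral-8x7B (illustrative), this builds an MoE block with num_local_experts=8 experts
        # and num_experts_per_tok=2, i.e. top-2 routing.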
self.input_layernorm = get_rmsnorm_cls(neuron_config)(
config.hidden_size,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = get_rmsnorm_cls(neuron_config)(
config.hidden_size,
eps=config.rms_norm_eps,
        )

    def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
position_ids (`torch.FloatTensor`, *optional*):
position ids of size `(batch_size, sequence_length)`.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
**kwargs,
)
hidden_states = residual + hidden_states
# MoE
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)[0]
hidden_states = residual + hidden_states
outputs = (hidden_states, present_key_value, cos_cache, sin_cache)
        return outputs


class NxDMixtralModel(NxDDecoderModel):
    """
    NxDMixtralModel extends MixtralModel to be traceable.
    The forward function of this class is traced.
    """
def __init__(self, config: MixtralConfig, neuron_config: NxDNeuronConfig):
super().__init__(config, neuron_config)
self.embed_tokens = ParallelEmbedding(
config.vocab_size,
config.hidden_size,
config.pad_token_id,
dtype=neuron_config.torch_dtype,
shard_across_embedding=True,
)
self.layers = nn.ModuleList(
[
NeuronMixtralDecoderLayer(config, neuron_config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = get_rmsnorm_cls(neuron_config)(config.hidden_size, eps=config.rms_norm_eps)
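        # When sampling on device, the LM head output stays sharded across tensor-parallel ranks;
        # otherwise it is gathered so the full logits can be returned to the host.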
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
gather_output=not neuron_config.on_device_sampling,
bias=False,
        )


class MixtralNxDModelForCausalLM(NxDModelForCausalLM):
    """
    This class can be used as a drop-in replacement for MixtralForCausalLM.
    """
    _model_cls = NxDMixtralModel

    @classmethod
def get_neuron_config_cls(cls):
        return NxDNeuronConfig

    @staticmethod
def convert_hf_to_neuron_state_dict(
state_dict: dict, config: MixtralConfig, neuron_config: NxDNeuronConfig
) -> dict:
        return convert_mixtral_to_neuron_state_dict(state_dict, config, neuron_config)

    @classmethod
def get_compiler_args(cls, neuron_config: NxDNeuronConfig) -> str:
compiler_args = "--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1"
# Add flags for cc-overlap
compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2'"
compiler_args += " --auto-cast=none"
# Enable vector-offset DGE
compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets"
        return compiler_args

    @classmethod
def _get_neuron_config(
cls,
checkpoint_id: str,
checkpoint_revision: str,
batch_size: int,
sequence_length: int,
tensor_parallel_size: int,
auto_cast_type: str,
):
return NxDNeuronConfig(
checkpoint_id=checkpoint_id,
checkpoint_revision=checkpoint_revision,
batch_size=batch_size,
sequence_length=sequence_length,
tp_degree=tensor_parallel_size,
torch_dtype=auto_cast_type,
)
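# A minimal sketch of building the model's NxDNeuronConfig via the helper above; the checkpoint id,
# revision, shapes and dtype below are illustrative assumptions, not defaults of this module:
#
#   neuron_config = MixtralNxDModelForCausalLM._get_neuron_config(
#       checkpoint_id="mistralai/Mixtral-8x7B-v0.1",
#       checkpoint_revision="main",
#       batch_size=1,
#       sequence_length=4096,
#       tensor_parallel_size=8,
#       auto_cast_type="bf16",
#   )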