# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/llama/modeling_llama.py
"""PyTorch Granite model for NXD inference."""

import logging
from typing import Any, Optional, Tuple

import torch
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
)
from torch import nn
from transformers.models.granite.configuration_granite import GraniteConfig

from ..backend.config import NxDNeuronConfig
from ..backend.modules.custom_calls import CustomRMSNorm
from ..backend.modules.decoder import NxDDecoderModel
from ..llama.modeling_llama import LlamaNxDModelForCausalLM, NeuronLlamaAttention, NeuronLlamaMLP

logger = logging.getLogger("Neuron")
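
# Granite differs from Llama by four scalar multipliers read from GraniteConfig,
# all of which are applied in this module:
#   - embedding_multiplier: scales the token embeddings,
#   - attention_multiplier: used as the qk scale in attention (in place of the
#     usual 1/sqrt(head_dim) scaling),
#   - residual_multiplier: scales the attention and MLP branch outputs,
#   - logits_scaling: divides the LM head logits.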


class NeuronGraniteDecoderLayer(nn.Module):
    """A Granite specific decoder layer with:

    - a custom scaling factor applied to the qk product in attention,
    - custom scaling factors applied to the attention and MLP outputs.
    """

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = NeuronLlamaAttention(config, neuron_config, qk_scale=config.attention_multiplier)
        self.mlp = NeuronLlamaMLP(config, neuron_config)
        logger.debug(
            f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}"
        )
        self.input_layernorm = CustomRMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = CustomRMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.qkv_kernel_enabled = neuron_config.qkv_kernel_enabled
        self.mlp_kernel_enabled = neuron_config.mlp_kernel_enabled
        self.mlp_kernel_fuse_residual_add = neuron_config.mlp_kernel_fuse_residual_add
        self.sequence_parallel_enabled = neuron_config.sequence_parallel_enabled
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        # RMSNorm (fused with the QKV kernel when sequence parallelism is disabled)
        if (not self.qkv_kernel_enabled or self.sequence_parallel_enabled) and self.input_layernorm:
            hidden_states = self.input_layernorm(hidden_states)

        # Self attention
        hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            rmsnorm=self.input_layernorm,
            **kwargs,
        )
        # Granite specific: the attention output is multiplied by the residual multiplier
        hidden_states = hidden_states * self.config.residual_multiplier

        if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add:
            assert not self.sequence_parallel_enabled, (
                "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled"
            )
            # First residual add handled in the MLP kernel
            hidden_states, residual = self.mlp(
                hidden_states,
                rmsnorm=self.post_attention_layernorm,
                residual=residual,
            )
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            # RMSNorm (fused with the MLP kernel when sequence parallelism is disabled)
            if not self.mlp_kernel_enabled or self.sequence_parallel_enabled:
                hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states, _ = self.mlp(
                hidden_states,
                rmsnorm=self.post_attention_layernorm,
            )

        # Granite specific: the MLP output is multiplied by the residual multiplier
        hidden_states = hidden_states * self.config.residual_multiplier
        hidden_states = residual + hidden_states

        outputs = (hidden_states, present_key_value, cos_cache, sin_cache)
        return outputs


class NxDGraniteEmbedding(ParallelEmbedding):
    """A custom Neuron parallel embedding layer with scaled outputs."""

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        super().__init__(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id,
            dtype=neuron_config.torch_dtype,
            shard_across_embedding=not neuron_config.vocab_parallel,
            sequence_parallel_enabled=False,
            pad=True,
            use_spmd_rank=neuron_config.vocab_parallel,
        )
        self.config = config

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Granite specific: embeddings are multiplied by a custom scale factor
        embeddings = super().forward(input_)
        return embeddings * self.config.embedding_multiplier
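

# Note on NxDGraniteEmbedding: Granite checkpoints typically ship an
# embedding_multiplier well above 1 (e.g. 12.0 in granite-3.x configs, an
# illustrative value, check the checkpoint's config.json), so the first decoder
# layer receives embedding_multiplier * embed_tokens(input_ids) rather than the
# raw embedding.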


class NxDGraniteHead(ColumnParallelLinear):
    """A custom LM head Neuron column parallel linear layer with scaled logits."""

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        super().__init__(
            config.hidden_size,
            config.vocab_size,
            gather_output=not neuron_config.on_device_sampling,
            bias=False,
            pad=True,
        )
        self.config = config

    def forward(self, input: torch.Tensor, *_: Any) -> torch.Tensor:
        logits = super().forward(input)
        # Granite specific: divide the logits by a custom scaling factor
        return logits / self.config.logits_scaling
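

# Note on NxDGraniteHead: dividing by logits_scaling acts like a fixed softmax
# temperature, i.e. next-token probabilities are computed as
# softmax((W @ h) / logits_scaling) where W is the LM head weight.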


class NxDGraniteModel(NxDDecoderModel):
    """
    The differences with the standard Neuron decoder are:

    - the use of scaled embeddings,
    - the use of the custom Granite layers with scaled attention and MLP,
    - the use of scaled head logits.
    """

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        super().__init__(config, neuron_config)
        self.embed_tokens = NxDGraniteEmbedding(config, neuron_config)
        self.lm_head = NxDGraniteHead(config, neuron_config)
        self.layers = nn.ModuleList(
            [NeuronGraniteDecoderLayer(config, neuron_config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = CustomRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class GraniteNxDModelForCausalLM(LlamaNxDModelForCausalLM):
    """This class extends LlamaNxDModelForCausalLM to create traceable blocks for Neuron.

    The only difference with the Llama version is the underlying model class.
    """

    _model_cls = NxDGraniteModel
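

if __name__ == "__main__":
    # Minimal usage sketch (an assumption of the standard optimum-neuron export
    # flow, not part of this module's API; the model id and compile-time shapes
    # below are illustrative, not prescriptive):
    from optimum.neuron import NeuronModelForCausalLM

    model = NeuronModelForCausalLM.from_pretrained(
        "ibm-granite/granite-3.1-2b-instruct",  # any Granite causal-LM checkpoint
        export=True,
        batch_size=1,
        sequence_length=1024,
    )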