optimum/neuron/models/inference/granite/modeling_granite.py (118 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/llama/modeling_llama.py
"""PyTorch Granite model for NXD inference."""

import logging
from typing import Any, Optional, Tuple

import torch
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
)
from torch import nn
from transformers.models.granite.configuration_granite import GraniteConfig

from ..backend.config import NxDNeuronConfig
from ..backend.modules.custom_calls import CustomRMSNorm
from ..backend.modules.decoder import NxDDecoderModel
from ..llama.modeling_llama import LlamaNxDModelForCausalLM, NeuronLlamaAttention, NeuronLlamaMLP


logger = logging.getLogger("Neuron")


class NeuronGraniteDecoderLayer(nn.Module):
    """Decoder layer built from the Llama Neuron attention/MLP blocks, with Granite scaling.

    Compared to a plain use of those blocks, this layer:
    - passes ``config.attention_multiplier`` to the attention module as ``qk_scale``,
    - multiplies both the attention output and the MLP output by
      ``config.residual_multiplier`` before each residual addition.
    """

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        """Create the attention, MLP and RMSNorm submodules and cache the kernel flags.

        Args:
            config: Granite model configuration; ``hidden_size``, ``rms_norm_eps``,
                ``attention_multiplier`` and (in ``forward``) ``residual_multiplier`` are read.
            neuron_config: Neuron configuration providing the kernel and
                sequence-parallelism switches.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # Flags consulted in forward() to decide where each RMSNorm is applied.
        self.qkv_kernel_enabled = neuron_config.qkv_kernel_enabled
        self.mlp_kernel_enabled = neuron_config.mlp_kernel_enabled
        self.mlp_kernel_fuse_residual_add = neuron_config.mlp_kernel_fuse_residual_add
        self.sequence_parallel_enabled = neuron_config.sequence_parallel_enabled

        # Submodules are created in a fixed order: attention, MLP, then the two norms.
        self.self_attn = NeuronLlamaAttention(config, neuron_config, qk_scale=config.attention_multiplier)
        self.mlp = NeuronLlamaMLP(config, neuron_config)

        logger.debug(
            f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}"
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = CustomRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = CustomRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run attention then MLP, scaling each branch before it joins the residual stream.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Forwarded unchanged to the attention module.
            position_ids: Forwarded unchanged to the attention module.
            past_key_value: Forwarded unchanged to the attention module.
            **kwargs: Extra keyword arguments forwarded to the attention module.

        Returns:
            A tuple ``(hidden_states, present_key_value, cos_cache, sin_cache)`` where the
            last three items are returned as-is by the attention module.
        """
        multiplier = self.config.residual_multiplier
        skip = hidden_states

        # The input RMSNorm runs here unless the QKV kernel is enabled without sequence
        # parallelism; the norm module is handed to the attention module in every case
        # (presumably so the kernel can fuse it).
        norm_before_attention = self.sequence_parallel_enabled or not self.qkv_kernel_enabled
        attn_input = hidden_states
        if norm_before_attention and self.input_layernorm:
            attn_input = self.input_layernorm(attn_input)

        # Self Attention
        attn_output, present_key_value, cos_cache, sin_cache = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            rmsnorm=self.input_layernorm,
            **kwargs,
        )

        # Granite specific: attention output is multiplied by residual multiplier
        attn_output = attn_output * multiplier

        fuse_residual_add = self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add
        if not fuse_residual_add:
            skip = skip + attn_output
            mlp_input = skip
            # The post-attention RMSNorm runs here unless the MLP kernel is enabled without
            # sequence parallelism; the norm module is passed to the MLP in every case.
            if self.sequence_parallel_enabled or not self.mlp_kernel_enabled:
                mlp_input = self.post_attention_layernorm(mlp_input)
            mlp_output, _ = self.mlp(
                mlp_input,
                rmsnorm=self.post_attention_layernorm,
            )
        else:
            assert not self.sequence_parallel_enabled, (
                "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled"
            )
            # First residual add handled in the MLP kernel, which also returns the
            # updated residual.
            mlp_output, skip = self.mlp(
                attn_output,
                rmsnorm=self.post_attention_layernorm,
                residual=skip,
            )

        # Granite specific: MLP output is multiplied by residual_multiplier
        mlp_output = mlp_output * multiplier

        return (skip + mlp_output, present_key_value, cos_cache, sin_cache)


class NxDGraniteEmbedding(ParallelEmbedding):
    """Parallel embedding whose output is multiplied by ``config.embedding_multiplier``."""

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        """Configure the parent embedding from the model and Neuron configurations."""
        vocab_parallel = neuron_config.vocab_parallel
        super().__init__(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id,
            dtype=neuron_config.torch_dtype,
            shard_across_embedding=not vocab_parallel,
            sequence_parallel_enabled=False,
            pad=True,
            use_spmd_rank=vocab_parallel,
        )
        self.config = config

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Return the parent embedding output scaled by the Granite embedding multiplier."""
        # Granite specific: embeddings are multiplied by custom scale factor
        return super().forward(input_) * self.config.embedding_multiplier


class NxDGraniteHead(ColumnParallelLinear):
    """Column-parallel LM head whose output is divided by ``config.logits_scaling``."""

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        """Configure a bias-free hidden_size -> vocab_size projection.

        The output is gathered only when on-device sampling is disabled.
        """
        gather_output = not neuron_config.on_device_sampling
        super().__init__(
            config.hidden_size,
            config.vocab_size,
            gather_output=gather_output,
            bias=False,
            pad=True,
        )
        self.config = config

    def forward(self, input: torch.Tensor, *_: Any) -> torch.Tensor:
        """Project ``input`` with the parent layer and rescale; extra positional args are ignored."""
        # Granite specific: divide logits by custom scaling factor
        return super().forward(input) / self.config.logits_scaling


class NxDGraniteModel(NxDDecoderModel):
    """Neuron decoder model assembled from the Granite-specific building blocks.

    The differences with the standard neuron decoder are:
    - the use of scaled embeddings,
    - the use of the custom granite layers with scaled attention and mlp,
    - the use of scaled head logits.
    """

    def __init__(self, config: GraniteConfig, neuron_config: NxDNeuronConfig):
        """Create the embedding, LM head, decoder layers and final RMSNorm."""
        super().__init__(config, neuron_config)

        self.embed_tokens = NxDGraniteEmbedding(config, neuron_config)
        self.lm_head = NxDGraniteHead(config, neuron_config)

        # One decoder layer per hidden layer declared in the configuration.
        decoder_layers = [
            NeuronGraniteDecoderLayer(config, neuron_config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = CustomRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class GraniteNxDModelForCausalLM(LlamaNxDModelForCausalLM):
    """Granite causal language model for Neuron.

    Extends the Llama NxD causal LM class and only overrides ``_model_cls`` so that
    ``NxDGraniteModel`` is used as the underlying model class.
    """

    _model_cls = NxDGraniteModel