optimum/neuron/models/training/granite/modeling_granite.py (236 lines of code) (raw):
# coding=utf-8
# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers.loss.loss_utils import ForCausalLMLoss
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.granite.configuration_granite import GraniteConfig
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, can_return_tuple, logging
from ....utils import is_neuronx_distributed_available, is_torch_xla_available
from ..config import TrainingNeuronConfig
from ..llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
# Optional accelerator dependencies. These names are only bound when the corresponding stack is installed;
# code paths that use them (gradient checkpointing, sequence parallelism) would hit a NameError otherwise.
if is_torch_xla_available():
    # XLA-aware activation checkpointing, used by `GraniteModel.forward` when gradient checkpointing is enabled.
    from torch_xla.utils.checkpoint import checkpoint

if is_neuronx_distributed_available():
    from neuronx_distributed.parallel_layers.mappings import (
        gather_from_sequence_parallel_region,
        scatter_to_sequence_parallel_region,
    )

    # Wrap the gather and scatter functions to ensure they are properly traced by `torch.fx`.
    gather_from_sequence_parallel_region = torch.fx.wrap(gather_from_sequence_parallel_region)
    scatter_to_sequence_parallel_region = torch.fx.wrap(scatter_to_sequence_parallel_region)

# Module-level logger (Transformers logging wrapper).
logger = logging.get_logger(__name__)
def rotate_half(x):
    """Rotate the last dimension of `x` by half of its size.

    The last dimension is split at ``size // 2`` into a leading part ``a`` and a trailing part ``b``;
    the result is ``concat(-b, a)`` along that same dimension.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((trailing.neg(), leading), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed rotary position embeddings (RoPE).

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine table of the rotary embedding.
        sin (`torch.Tensor`): The sine table of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Ignored. Accepted only so that older call sites keep working.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into `cos` and `sin` so that they broadcast
            against `q` and `k`. For example, with `cos`/`sin` shaped `[batch_size, seq_len, head_dim]`,
            use 1 when `q`/`k` are `[batch_size, heads, seq_len, head_dim]`, and 2 when they are
            `[batch_size, seq_len, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)`: the rotated `(query, key)` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Standard RoPE: t * cos + rotate_half(t) * sin.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along dim 1, like `torch.repeat_interleave(x, dim=1, repeats=n_rep)`.

    Maps `(batch, num_key_value_heads, seqlen, head_dim)` to `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking first means a non-4D input raises even when no repetition is requested.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # `expand` adds the repeat axis without copying; flattening it into the head axis interleaves the copies.
        widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
class GraniteAttention(LlamaAttention):
    """Llama training attention with Granite's configurable attention scale.

    The only change with respect to `LlamaAttention` is that `self.scaling` is set from
    `config.attention_multiplier` after the base initializer has run.
    """

    def __init__(self, config: GraniteConfig, trn_config: TrainingNeuronConfig, layer_idx: int):
        super().__init__(config, trn_config, layer_idx)
        # NOTE(review): presumably consumed by the inherited attention forward as the score scale — confirm in
        # LlamaAttention.
        attention_scale = config.attention_multiplier
        self.scaling = attention_scale
class GraniteDecoderLayer(LlamaDecoderLayer):
    """Llama training decoder layer with Granite's scaled residual connections.

    The outputs of the self-attention and MLP sub-blocks are each multiplied by `config.residual_multiplier`
    before being added back to the residual stream. The attention module is a `GraniteAttention`.
    """

    def __init__(self, config: GraniteConfig, trn_config: TrainingNeuronConfig, layer_idx: int):
        super().__init__(config, trn_config, layer_idx)
        self.residual_multiplier = config.residual_multiplier
        # Replace the attention module created by the Llama base with the Granite variant.
        self.self_attn = GraniteAttention(config=config, trn_config=trn_config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply pre-norm self-attention and MLP sub-blocks, each with a scaled residual connection.

        Extra `kwargs` are forwarded unchanged to `self_attn`.

        Returns:
            A tuple whose first element is the layer output. When `output_attentions` is truthy, the attention
            weights returned by `self_attn` are appended as a second element.
        """
        scale = self.residual_multiplier

        # Self-attention sub-block (`hidden_states` still holds the residual here).
        attn_input = self.input_layernorm(hidden_states)
        attn_output, attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output * scale

        # Feed-forward sub-block; scaling the residual branch is the main diff with Llama.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output * scale

        return (hidden_states, attn_weights) if output_attentions else (hidden_states,)
class GraniteModel(LlamaModel):
    """Granite transformer backbone (embeddings, decoder stack, final norm) for Neuron training.

    Built from the Llama training components (`LlamaRMSNorm`, `LlamaRotaryEmbedding`, and methods inherited from
    `LlamaModel` such as `_update_causal_mask`). The Granite-specific changes visible here are twofold. Token
    embeddings are multiplied by `config.embedding_multiplier`, and the layers are `GraniteDecoderLayer`s.
    """

    config_class = GraniteConfig

    def __init__(self, config: GraniteConfig, trn_config: TrainingNeuronConfig):
        # Initialize through `LlamaPreTrainedModel` directly, skipping `LlamaModel.__init__`
        # (presumably so that Llama's own submodules are not built before the Granite ones below — confirm).
        LlamaPreTrainedModel.__init__(self, config)
        self.embedding_multiplier = config.embedding_multiplier
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.trn_config = trn_config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [GraniteDecoderLayer(config, trn_config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # The final norm is told whether sequence parallelism is on (its input is then sharded along the sequence).
        self.norm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, sequence_parallel_enabled=trn_config.sequence_parallel_enabled
        )
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        # Gradient checkpointing is driven by the Neuron training config, not by the HF model config.
        self.gradient_checkpointing = self.trn_config.gradient_checkpointing

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embeddings, the decoder stack and the final norm.

        Exactly one of `input_ids` / `inputs_embeds` must be given. `use_cache` is accepted for API compatibility
        only. No cache is built, and `past_key_values` in the output is always `None`.

        When `trn_config.sequence_parallel_enabled` is set, activations are kept in sequence-first layout and
        sharded along the sequence dimension across tensor-parallel ranks. `last_hidden_state` and any returned
        hidden states are therefore in that sharded layout; `GraniteForCausalLM` gathers after the LM head.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # XOR: true when both or neither of the two inputs are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.trn_config.sequence_parallel_enabled:
            # Switch to sequence-first layout, then keep only this rank's slice of the sequence dimension.
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)

        inputs_embeds = inputs_embeds * self.embedding_multiplier  # main diff with Llama

        # Full (unsharded) sequence length. With sequence parallelism, dim 0 holds only
        # 1 / tensor_parallel_size of the sequence on each rank.
        current_length = (
            inputs_embeds.size(0) * self.trn_config.tensor_parallel_size
            if self.trn_config.sequence_parallel_enabled
            else inputs_embeds.size(1)
        )
        cache_position = torch.arange(0, current_length, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if self.trn_config.recompute_causal_mask:
            # No mask is materialized here and the layers receive `attention_mask=None`.
            # NOTE(review): presumably the attention rebuilds the causal mask itself — confirm in LlamaAttention.
            causal_mask = None
        else:
            causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional, in the order of `GraniteDecoderLayer.forward`'s parameters:
                # (hidden_states, attention_mask, position_ids, output_attentions, position_embeddings).
                # NOTE(review): `flash_attn_kwargs` are not forwarded on this path, unlike the branch below.
                layer_outputs = checkpoint(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    output_attentions,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    output_attentions=output_attentions,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # append the final, post-norm hidden states
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    """Extra keyword arguments accepted by `GraniteForCausalLM.forward`: flash-attention kwargs plus loss kwargs."""
class GraniteForCausalLM(LlamaForCausalLM):
    """Granite causal language model for Neuron training.

    Combines a `GraniteModel` backbone with a bias-free linear LM head and divides the logits by
    `config.logits_scaling`. Pipeline parallelism is not supported for this architecture.
    """

    config_class = GraniteConfig

    SUPPORTS_PIPELINE_PARALLELISM = False
    PIPELINE_TRANSFORMER_LAYER_CLS = GraniteDecoderLayer
    PIPELINE_INPUT_NAMES = ["input_ids", "attention_mask", "labels"]
    PIPELINE_LEAF_MODULE_CLASSE_NAMES = ["LlamaRMSNorm", "LlamaRotaryEmbedding"]

    def __init__(self, config, trn_config: TrainingNeuronConfig):
        # Initialize through `LlamaPreTrainedModel` directly, skipping `LlamaForCausalLM.__init__`,
        # so that the backbone is the `GraniteModel` created here.
        LlamaPreTrainedModel.__init__(self, config)
        self.trn_config = trn_config
        self.model = GraniteModel(config, trn_config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute the LM logits and, when `labels` is given, the causal LM loss.

        Extra `kwargs` are forwarded both to the backbone and to `ForCausalLMLoss`. With sequence parallelism,
        the sharded logits are gathered along the sequence dimension and returned in batch-first layout.
        `past_key_values` is always `None`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        backbone_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )

        # Dividing the LM-head output by `logits_scaling` is the main diff with Llama.
        logits = self.lm_head(backbone_out[0]) / self.config.logits_scaling
        if self.trn_config.sequence_parallel_enabled:
            # Reassemble the full sequence from the per-rank shards, then go back to batch-first layout.
            logits = gather_from_sequence_parallel_region(logits).transpose(0, 1).contiguous()

        loss = (
            ForCausalLMLoss(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )