optimum/neuron/models/inference/t5/modeling_t5.py (234 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/llama/modeling_llama.py
"""PyTorch T5 model for NXD inference."""

from typing import Optional

import torch
from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.parallel_layers.layers import (
    BaseParallelLinear,
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from neuronx_distributed.parallel_layers.utils import divide
from torch import nn
from transformers import T5Config
from transformers.activations import ACT2FN
from transformers.models.t5.modeling_t5 import (
    T5Attention,
    T5DenseActDense,
    T5DenseGatedActDense,
    T5LayerCrossAttention,
    T5LayerFF,
    T5LayerNorm,
    T5LayerSelfAttention,
)
from transformers.pytorch_utils import find_pruneable_heads_and_indices


# T5 NxD custom modeling, copied from:
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/neuronx_distributed/t5-inference/t5-inference-tutorial.html.


def prune_linear_layer(layer: BaseParallelLinear, index: torch.LongTensor, dim: int = 0) -> BaseParallelLinear:
    """
    Prune a tensor-parallel linear layer to keep only entries in index.

    Used to remove heads. `index` indexes the *local* (this rank's) weight shard, e.g. as returned by
    `find_pruneable_heads_and_indices` called with the per-partition head count.

    Args:
        layer (`BaseParallelLinear`):
            The layer to prune. A `RowParallelLinear` is rebuilt as a `RowParallelLinear` (with
            `input_is_parallel=True`); any other parallel linear is rebuilt as a `ColumnParallelLinear`
            (with `gather_output=False`).
        index (`torch.LongTensor`):
            The indices to keep in the layer.
        dim (`int`, *optional*, defaults to 0):
            The dimension on which to keep the indices (0 prunes output features, 1 prunes input features).

    Returns:
        `BaseParallelLinear`: The pruned layer as a new layer with `requires_grad=True`.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    has_bias = layer.bias is not None
    if has_bias:
        if dim == 1:
            # Pruning input features leaves the output features (and hence the bias) untouched.
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)

    # `new_size` is the size of the local shard, but the parallel layer constructors take *global* sizes and divide
    # the sharded dimension by the world size: scale the sharded dimension back up so the new local weight has
    # exactly the shape of `W`.
    world_size = parallel_state.get_tensor_model_parallel_size()
    if isinstance(layer, RowParallelLinear):
        # Row-parallel layers (e.g. the attention output projection) shard their input features; the rebuilt layer
        # must stay row-parallel so that the partial outputs of each rank are still reduced together.
        new_layer = RowParallelLinear(new_size[1] * world_size, new_size[0], bias=has_bias, input_is_parallel=True)
    else:
        # Column-parallel layers (q/k/v projections) shard their output features.
        new_layer = ColumnParallelLinear(new_size[1], new_size[0] * world_size, bias=has_bias, gather_output=False)
    new_layer = new_layer.to(layer.weight.device)

    # In-place copy into a leaf parameter is only allowed while it does not require grad.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if has_bias:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer


class NeuronT5Attention(T5Attention):
    """
    Tensor-parallel T5 attention.

    The q/k/v projections are `ColumnParallelLinear` layers (heads split across tensor-parallel ranks) and the
    output projection is a `RowParallelLinear` consuming the per-rank heads. After construction, `n_heads` holds
    the number of heads handled by this rank, not the global head count.
    """

    def __init__(
        self,
        config: T5Config,
        has_relative_attention_bias=False,
        layer_idx: Optional[int] = None,
    ):
        super().__init__(config, has_relative_attention_bias, layer_idx)
        # Per attention head and per partition values
        world_size = parallel_state.get_tensor_model_parallel_size()
        self.num_attention_heads_per_partition = divide(self.n_heads, world_size)
        self.hidden_size_per_partition = self.num_attention_heads_per_partition * self.key_value_proj_dim

        # Replace the dense projections built by `T5Attention.__init__` with parallel ones (global sizes are passed,
        # the layers shard them). T5 follows Mesh TensorFlow: no 1/sqrt(d) query scaling is applied in `forward`.
        self.q = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False, gather_output=False)
        self.k = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False, gather_output=False)
        self.v = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False, gather_output=False)
        self.o = RowParallelLinear(self.inner_dim, self.d_model, bias=False, input_is_parallel=True)

        if self.has_relative_attention_bias:
            # Built for all heads; `compute_bias` slices out this rank's heads.
            self.relative_attention_bias = ParallelEmbedding(self.relative_attention_num_buckets, self.n_heads)
        self.n_heads = self.num_attention_heads_per_partition

    def prune_heads(self, heads):
        """
        Prune attention heads of this rank's partition.

        Args:
            heads: Indices of the heads to prune, relative to this rank's `num_attention_heads_per_partition` heads
                (they are forwarded to `find_pruneable_heads_and_indices` with the per-partition head count).
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads_per_partition, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.num_attention_heads_per_partition = self.num_attention_heads_per_partition - len(heads)
        self.hidden_size_per_partition = self.key_value_proj_dim * self.num_attention_heads_per_partition
        # Keep `n_heads` equal to the per-partition head count, as set at the end of `__init__`.
        self.n_heads = self.num_attention_heads_per_partition
        self.pruned_heads = self.pruned_heads.union(heads)

    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """
        Compute binned relative position bias for the heads handled by this rank.

        Returns:
            A tensor of shape `(1, num_attention_heads_per_partition, query_length, key_length)`.
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)
        # TP: keep only this rank's heads along the last dimension.
        # NOTE(review): assumes the ParallelEmbedding lookup returns all heads on every rank — confirm.
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        values = values[
            :,
            :,
            tp_rank * self.num_attention_heads_per_partition : (tp_rank + 1) * self.num_attention_heads_per_partition,
        ]
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns:
            A tuple `(attn_output, past_key_value, position_bias)`, followed by the attention weights when
            `output_attentions=True`.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder)
        # or (batch_size, 1, seq_length, key_length) (causal decoder)
        # NOTE: forced on every call; this module is only swapped into decoder blocks (see `parallelize`), and the
        # flag makes `compute_bias` use unidirectional buckets regardless of the config it was built from.
        self.is_decoder = True
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(
            batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
        ).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(
                batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
            ).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states),
        # compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.num_attention_heads_per_partition, seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype,
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            # Drop the bias of pruned heads (the incoming bias may have been computed for all heads).
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        # Merge this rank's heads back to (batch_size, seq_length, hidden_size_per_partition) for the row-parallel o.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.hidden_size_per_partition)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs


class NeuronT5LayerSelfAttention(T5LayerSelfAttention):
    """T5 self-attention layer whose attention module is a `NeuronT5Attention`."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        # The parent's attention module is replaced right below, so build it without a relative bias table.
        super().__init__(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.SelfAttention = NeuronT5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            layer_idx=layer_idx,
        )
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)


class NeuronT5LayerCrossAttention(T5LayerCrossAttention):
    """T5 cross-attention layer whose attention module is a `NeuronT5Attention`."""

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config)
        self.EncDecAttention = NeuronT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)


class NeuronT5DenseActDense(T5DenseActDense):
    """T5 feed-forward block (wi -> activation -> wo) with column-parallel `wi` and row-parallel `wo`."""

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.wi = ColumnParallelLinear(config.d_model, config.d_ff, gather_output=False, bias=False)
        self.wo = RowParallelLinear(config.d_ff, config.d_model, input_is_parallel=True, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]


class NeuronT5DenseGatedActDense(T5DenseGatedActDense):
    """Gated T5 feed-forward block with column-parallel `wi_0`/`wi_1` and row-parallel `wo`."""

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.wi_0 = ColumnParallelLinear(config.d_model, config.d_ff, gather_output=False, bias=False)
        self.wi_1 = ColumnParallelLinear(config.d_model, config.d_ff, gather_output=False, bias=False)
        self.wo = RowParallelLinear(config.d_ff, config.d_model, input_is_parallel=True, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]


class NeuronT5LayerFF(T5LayerFF):
    """T5 feed-forward layer using the gated or plain parallel dense block depending on `config.is_gated_act`."""

    def __init__(self, config: T5Config):
        super().__init__(config)
        if config.is_gated_act:
            self.DenseReluDense = NeuronT5DenseGatedActDense(config)
        else:
            self.DenseReluDense = NeuronT5DenseActDense(config)

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)


def parallelize(model):
    """
    Swap the decoder layers of a T5 model for their tensor-parallel equivalents, in place.

    For every block of `model.decoder`, the self-attention (layer 0), cross-attention (layer 1) and feed-forward
    (layer 2) sub-layers are replaced by freshly built Neuron layers constructed from `model.config`; only the first
    block gets a relative attention bias. Encoder blocks are left untouched, and no weights are copied from the
    replaced layers — presumably they are loaded afterwards (TODO confirm against callers).

    Returns:
        The same `model` instance.
    """
    for index, block in enumerate(model.decoder.block):
        block.layer[0] = NeuronT5LayerSelfAttention(
            model.config,
            has_relative_attention_bias=index == 0,
            layer_idx=index,
        )
        block.layer[1] = NeuronT5LayerCrossAttention(model.config, layer_idx=index)
        block.layer[2] = NeuronT5LayerFF(model.config)
    return model