megatron_patch/model/glm130b/transformer.py (875 lines of code) (raw):

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from contextlib import nullcontext from typing import Optional import torch import torch.nn.functional as F from torch import Tensor from megatron import core from megatron import get_args from megatron import get_num_microbatches from megatron.core import mpu, tensor_parallel from megatron.model import LayerNorm from megatron.model.enums import AttnMaskType from megatron.model.enums import AttnType from megatron.model.enums import LayerType from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.module import MegatronModule from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb from megatron.model.utils import attention_mask_func from megatron.model.utils import openai_gelu from megatron.model.utils import erf_gelu from megatron.core.enums import ModelType try: from einops import rearrange except ImportError: rearrange = None try: from flash_attn.flash_attn_interface import flash_attn_unpadded_func except ImportError: flash_attn_unpadded_func = None """ We use the following notation throughout this file: h: hidden size n: number of attention heads p: number of model parallel partitions np: n/p hp: h/p hn: h/n b: batch size s: sequence length l: number of layers Transformer takes input of size [s, b, h] and returns a tensor of the same size. We use the following arguments: hyperparameters: transformer hyperparameters """ class DropPath(MegatronModule): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=0.): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, hidden_state): if self.drop_prob == 0. or not self.training: return hidden_state keep_prob = 1 - self.drop_prob # work with diff dim tensors, not just 2D ConvNets # hidden_state: [s, b, h] shape = (1, ) + ( hidden_state.shape[1], ) + (1, ) * (hidden_state.ndim - 2) random_tensor =\ keep_prob + torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) random_tensor.floor_() # binarize output = hidden_state.div(keep_prob) * random_tensor return output def _args_to_kwargs(): args = get_args() common_kwargs = { 'params_dtype': args.params_dtype, 'use_cpu_initialization': args.use_cpu_initialization, 'perform_initialization': args.perform_initialization, 'gradient_accumulation_fusion': args.gradient_accumulation_fusion, 'sequence_parallel_enabled': args.sequence_parallel, } return common_kwargs class ParallelMLP(MegatronModule): """MLP. MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. """ def __init__(self, init_method, output_layer_init_method): super(ParallelMLP, self).__init__() args = get_args() self.add_bias = args.add_bias_linear inner_hidden_size = 4 * args.hidden_size * 2 // 3 self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( args.hidden_size, 2 * inner_hidden_size, bias=self.add_bias, gather_output=False, init_method=init_method, skip_bias_add=False, async_tensor_model_parallel_allreduce=args. async_tensor_model_parallel_allreduce, **_args_to_kwargs()) self.bias_gelu_fusion = False self.activation_func = None self.swiglu = args.swiglu if args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: self.activation_func = erf_gelu elif args.swiglu: def swiglu(x): x = torch.chunk(x, 2, dim=-1) return F.silu(x[0]) * x[1] self.activation_func = swiglu elif args.squared_relu: def squared_relu(x): return torch.pow(F.relu(x), 2) self.activation_func = squared_relu elif args.geglu: def geglu(x): x1, x2 = x.chunk(2, dim=(x.ndim - 1)) return x1 * F.gelu(x2) self.activation_func = geglu else: self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu # Project back to h. self.dense_4h_to_h = tensor_parallel.RowParallelLinear( inner_hidden_size, args.hidden_size, bias=self.add_bias, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=False, **_args_to_kwargs()) def forward(self, hidden_states): # [s, b, 4hp] intermediate_parallel, bias_parallel = self.dense_h_to_4h( hidden_states) if self.bias_gelu_fusion: assert self.add_bias is True assert self.activation_func == F.gelu intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) else: if bias_parallel is not None: intermediate_parallel = intermediate_parallel + bias_parallel intermediate_parallel = self.activation_func(intermediate_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) return output, output_bias class CoreAttention(MegatronModule): def __init__(self, layer_number, attn_mask_type=AttnMaskType.padding): super(CoreAttention, self).__init__() args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True self.layer_number = max(1, layer_number) self.attn_mask_type = attn_mask_type self.sequence_parallel = args.sequence_parallel projection_size = args.kv_channels * args.num_attention_heads # Per attention head and per partition values. world_size = mpu.get_tensor_model_parallel_world_size() self.hidden_size_per_partition = core.utils.divide( projection_size, world_size) self.hidden_size_per_attention_head = core.utils.divide( projection_size, args.num_attention_heads) self.num_attention_heads_per_partition = core.utils.divide( args.num_attention_heads, world_size) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff self.scale_mask_softmax = FusedScaleMaskSoftmax( self.fp16, self.bf16, self.attn_mask_type, args.masked_softmax_fusion, attention_mask_func, self.attention_softmax_in_fp32, coeff) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. self.attention_dropout = torch.nn.Dropout(args.attention_dropout) def forward(self, query_layer, key_layer, value_layer, attention_mask): # =================================== # Raw attention scores. [b, np, s, s] # =================================== query_key_layer_scaling_coeff = float(self.layer_number) query_layer = query_layer / ( math.sqrt(self.hidden_size_per_attention_head) * query_key_layer_scaling_coeff) # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( (output_size[0] * output_size[1], output_size[2], output_size[3]), query_layer.dtype, 'mpu') matmul_result = torch.baddbmm( matmul_input_buffer, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=1.0) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # =========================== # Attention probs and dropout # =========================== # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. if not self.sequence_parallel: with tensor_parallel.get_cuda_rng_tracker().fork(): attention_probs = self.attention_dropout(attention_probs) else: attention_probs = self.attention_dropout(attention_probs) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. # [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) # change view [sk, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) return context_layer class FlashSelfAttention(torch.nn.Module): """Implement the scaled dot product attention with softmax. Arguments --------- softmax_scale: The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime) attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): super().__init__() assert flash_attn_unpadded_func is not None, ( 'Please install FlashAttention first, ' 'e.g., with pip install flash-attn') assert rearrange is not None,\ 'Please install einops first, e.g., with pip install einops' self.causal = causal self.softmax_scale = softmax_scale self.dropout_p = attention_dropout def forward(self, q, k, v): """Implements the multihead softmax attention. Arguments --------- q, k, v: The tensor containing the query, key, and value. (B, S, H, D) """ assert all( (i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) assert all((i.is_cuda for i in (q, k, v))) batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = k.shape[1] q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device) if self.training: # during training q,k,v always have same seqlen assert seqlen_k == seqlen_q is_causal = self.causal cu_seqlens_k = cu_seqlens_q else: # turn off FA causal mask after first # inference autoregressive iteration # only on first autoregressive step # q,k,v have same seqlen is_causal = seqlen_q == seqlen_k cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=q.device) self.dropout_p = 0 output = flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, self.dropout_p, softmax_scale=self.softmax_scale, causal=is_causal) output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) return output class ParallelAttention(MegatronModule): """Parallel self-attention layer abstract class. Self-attention layer takes input with size [s, b, h] and returns output of the same size. """ def __init__(self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding): super(ParallelAttention, self).__init__() args = get_args() self.layer_number = max(1, layer_number) self.attention_type = attention_type self.attn_mask_type = attn_mask_type self.params_dtype = args.params_dtype self.sequence_parallel = args.sequence_parallel self.use_flash_attn = args.use_flash_attn if self.use_flash_attn: if flash_attn_unpadded_func is None: raise ImportError( 'FlashAttention is not installed, please install with ' 'pip install flash-attn') assert attention_type == AttnType.self_attn, ( 'FlashAttention code path only supports ' 'self-attention for now') assert self.attn_mask_type == AttnMaskType.causal, ( 'FlashAttention code path only ' 'supports causal mask for now') if rearrange is None: raise ImportError('einops is not installed, ' 'please install with pip install einops') projection_size = args.kv_channels * args.num_attention_heads # Per attention head and per partition values. world_size = mpu.get_tensor_model_parallel_world_size() self.hidden_size_per_attention_head = core.utils.divide( projection_size, args.num_attention_heads) self.num_attention_heads_per_partition = core.utils.divide( args.num_attention_heads, world_size) # Strided linear layer. if attention_type == AttnType.self_attn: self.query_key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 3 * projection_size, bias=args.add_bias_linear, gather_output=False, init_method=init_method, async_tensor_model_parallel_allreduce=args. async_tensor_model_parallel_allreduce, **_args_to_kwargs()) else: assert attention_type == AttnType.cross_attn self.query = tensor_parallel.ColumnParallelLinear( args.hidden_size, projection_size, bias=args.add_bias_linear, gather_output=False, init_method=init_method, async_tensor_model_parallel_allreduce=args. async_tensor_model_parallel_allreduce, **_args_to_kwargs()) self.key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 2 * projection_size, bias=args.add_bias_linear, gather_output=False, init_method=init_method, async_tensor_model_parallel_allreduce=args. async_tensor_model_parallel_allreduce, **_args_to_kwargs()) self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type) self.checkpoint_core_attention =\ args.recompute_granularity == 'selective' if self.use_flash_attn: self.core_attention_flash = FlashSelfAttention( causal=True, attention_dropout=args.attention_dropout) # Output. self.dense = tensor_parallel.RowParallelLinear( projection_size, args.hidden_size, bias=args.add_bias_linear, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=False, **_args_to_kwargs()) def _checkpointed_attention_forward(self, query_layer, key_layer, value_layer, attention_mask, rotary_pos_emb=None): """Forward method with activation checkpointing.""" def custom_forward(*inputs): query_layer = inputs[0] key_layer = inputs[1] value_layer = inputs[2] attention_mask = inputs[3] output_ = self.core_attention(query_layer, key_layer, value_layer, attention_mask) return output_ q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ else rotary_pos_emb hidden_states = tensor_parallel.checkpoint(custom_forward, False, query_layer, key_layer, value_layer, attention_mask, q_pos_emb, k_pos_emb) return hidden_states def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty(inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) def forward(self, hidden_states, attention_mask, encoder_output=None, inference_params=None, rotary_pos_emb=None): # hidden_states: [sq, b, h] # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= is_first_step = False if inference_params: if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_len inf_max_batch_size = inference_params.max_batch_size inference_key_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) inference_value_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) inference_params.key_value_memory_dict[self.layer_number] = ( inference_key_memory, inference_value_memory) is_first_step = True else: inference_key_memory, inference_value_memory = \ inference_params.key_value_memory_dict[self.layer_number] # ===================== # Query, Key, and Value # ===================== if self.attention_type == AttnType.self_attn: # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( mixed_x_layer, 3) else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] new_tensor_shape = mixed_kv_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, 2 * self.hidden_size_per_attention_head) mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( mixed_kv_layer, 2) # Attention head [sq, b, h] --> [sq, b, hp] query_layer, _ = self.query(hidden_states) # [sq, b, hp] --> [sq, b, np, hn] new_tensor_shape = query_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_tensor_shape) # ================================== # Adjust key and value for inference # ================================== # duplicate the pos_emb for self attention if rotary_pos_emb is not None: if isinstance(rotary_pos_emb, tuple): rotary_pos_emb = rotary_pos_emb else: rotary_pos_emb = ((rotary_pos_emb, ) * 2) if inference_params: batch_start = inference_params.batch_size_offset batch_end = batch_start + key_layer.size(1) assert batch_end <= inference_key_memory.size(1) sequence_start = inference_params.sequence_len_offset sequence_end = sequence_start + key_layer.size(0) assert sequence_end <= inference_key_memory.size(0) # Copy key and values. inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] # adjust the key rotary positional embedding if rotary_pos_emb is not None: q_pos_emb, k_pos_emb = rotary_pos_emb # need to cross check this condition during inference # if not set_inference_key_value_memory: if not is_first_step: # In inference, we compute one token at a time. # Select the correct positional embedding # (only the last token in the sequence) q_pos_emb = q_pos_emb[sequence_end - 1:sequence_end] else: # In the first forward pass of inference, # we use the entire provided prefix. # q_pos_emb here has the rope embeddings of the entire # prefix + to-be-generated output so # we slice to just the prefix. q_pos_emb = q_pos_emb[:sequence_end, :, :, :] k_pos_emb = k_pos_emb[:sequence_end, :, :, :] rotary_pos_emb = (q_pos_emb, k_pos_emb) # ================================== # core attention computation # ================================== # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: q_pos_emb, k_pos_emb = rotary_pos_emb query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) # TODO, can apply positional embedding to value_layer so it has # absolute positional embedding. # otherwise, only relative positional embedding takes effect # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) if not self.use_flash_attn: if self.checkpoint_core_attention: context_layer = self._checkpointed_attention_forward( query_layer, key_layer, value_layer, attention_mask) else: context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) else: q, k, v = [ rearrange(x, 's b ... -> b s ...').contiguous() for x in (query_layer, key_layer, value_layer) ] if not self.sequence_parallel: with tensor_parallel.get_cuda_rng_tracker().fork(): context_layer = self.core_attention_flash(q, k, v) else: context_layer = self.core_attention_flash(q, k, v) context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() # ================= # Output. [sq, b, h] # ================= output, bias = self.dense(context_layer) return output, bias def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor if bias is not None: x = x + bias out = torch.nn.functional.dropout(x, p=prob, training=training) out = residual + out return out def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) return _bias_dropout_add @torch.jit.script def bias_dropout_add_fused_train(x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) class ParallelTransformerLayer(MegatronModule): """A single transformer layer. Transformer layer takes input with size [s, b, h] and returns an output of the same size. """ def __init__(self, init_method, output_layer_init_method, layer_number, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding, drop_path_rate=0.): args = get_args() super(ParallelTransformerLayer, self).__init__() self.layer_number = layer_number self.layer_type = layer_type self.num_layers = args.num_layers self.apply_residual_connection_post_layernorm \ = args.apply_residual_connection_post_layernorm self.bf16 = args.bf16 self.fp32_residual_connection = args.fp32_residual_connection # Layernorm on the input data. self.input_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon, no_persist_layer_norm=args.no_persist_layer_norm, sequence_parallel=args.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p) # Self attention. self.self_attention = ParallelAttention( init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, attn_mask_type=self_attn_mask_type) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion self.drop_path = DropPath( drop_path_rate) if drop_path_rate > 0.0 else None # Layernorm on the attention output self.post_attention_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon, no_persist_layer_norm=args.no_persist_layer_norm, sequence_parallel=args.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p) if self.layer_type == LayerType.decoder: self.inter_attention = ParallelAttention( init_method, output_layer_init_method, layer_number, attention_type=AttnType.cross_attn) # Layernorm on the attention output. self.post_inter_attention_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon, no_persist_layer_norm=args.no_persist_layer_norm, sequence_parallel=args.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p) # MLP self.mlp = ParallelMLP(init_method, output_layer_init_method) # Set bias+dropout+add fusion grad_enable execution handler. TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) self.bias_dropout_add_exec_handler =\ nullcontext if use_nvfuser else torch.enable_grad def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None, rotary_pos_emb=None): # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.self_attention( layernorm_output, attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb) alpha = (2 * self.num_layers)**0.5 # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output * alpha else: residual = hidden_states layernorm_input = attention_output + residual # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output * alpha else: residual = layernorm_input output = mlp_output + residual # Jit compiled function creates 'view' tensor. This tensor # potentially gets saved in the MPU checkpoint function context, # which rejects view tensors. While making a viewless tensor here # won't result in memory savings (like the data loader, or # p2p_communication), it serves to document the origin of this # 'view' tensor. output = core.utils.make_viewless_tensor( inp=output, requires_grad=output.requires_grad, keep_graph=True) return output def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False): """Compute the number of transformer layers resident on the current rank.""" if mpu.get_pipeline_model_parallel_world_size() > 1: if is_encoder_and_decoder_model: assert args.pipeline_model_parallel_split_rank is not None # When a standalone embedding stage is used, a rank is taken from # the encoder's ranks, to be used for the encoder's embedding # layer. This way, the rank referenced by the 'split rank' remains # the same whether or not a standalone embedding stage is used. num_ranks_in_encoder = (args.pipeline_model_parallel_split_rank - 1 if args.standalone_embedding_stage else args.pipeline_model_parallel_split_rank) num_ranks_in_decoder =\ args.transformer_pipeline_model_parallel_size -\ num_ranks_in_encoder assert args.encoder_num_layers % num_ranks_in_encoder == 0 assert args.decoder_num_layers % num_ranks_in_decoder == 0 if mpu.is_pipeline_stage_before_split(): num_layers = (0 if args.standalone_embedding_stage and mpu.get_pipeline_model_parallel_rank() == 0 else args.encoder_num_layers // num_ranks_in_encoder) else: num_layers = args.decoder_num_layers // num_ranks_in_decoder else: assert args.num_layers == args.encoder_num_layers assert args.num_layers %\ args.transformer_pipeline_model_parallel_size == 0 num_layers = (0 if args.standalone_embedding_stage and mpu.get_pipeline_model_parallel_rank() == 0 else args.num_layers // args.transformer_pipeline_model_parallel_size) else: if not is_decoder: num_layers = args.encoder_num_layers else: num_layers = args.decoder_num_layers return num_layers class ParallelTransformer(MegatronModule): """Transformer class.""" def __init__(self, init_method, output_layer_init_method, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding, post_layer_norm=True, pre_process=True, post_process=True, drop_path_rate=0.0): super(ParallelTransformer, self).__init__() args = get_args() self.layer_type = layer_type self.model_type = args.model_type self.bf16 = args.bf16 self.fp32_residual_connection = args.fp32_residual_connection self.post_layer_norm = post_layer_norm self.pre_process = pre_process self.post_process = post_process self.input_tensor = None self.drop_path_rate = drop_path_rate self.transformer_impl = args.transformer_impl # Store activation checkpoiting flag. self.recompute_granularity = args.recompute_granularity self.recompute_method = args.recompute_method self.recompute_num_layers = args.recompute_num_layers self.distribute_saved_activations = \ args.distribute_saved_activations and not args.sequence_parallel self.sequence_parallel = args.sequence_parallel # Transformer Engine Init. if self.transformer_impl == 'transformer_engine': global transformer_engine import transformer_engine self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid self.fp8_recipe = None self.fp8_group = None if self.use_fp8: self.fp8_group = mpu.get_data_parallel_group() if args.fp8_e4m3: fp8_format = transformer_engine.common.recipe.Format.E4M3 elif args.fp8_hybrid: fp8_format = transformer_engine.common.recipe.Format.HYBRID self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( margin=args.fp8_margin, interval=args.fp8_interval, fp8_format=fp8_format, amax_history_len=args.fp8_amax_history_len, amax_compute_algo=args.fp8_amax_compute_algo, override_linear_precision=(False, False, not args.fp8_wgrad), ) self.num_microbatches_in_previous_step = -1 self.microbatch_count = 0 self.checkpoint_core_attention =\ args.recompute_granularity == 'selective' # Number of layers. self.num_layers = _get_num_layers( args, args.model_type == ModelType.encoder_and_decoder, layer_type == LayerType.decoder) self.drop_path_rates = [ rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers) ] # Transformer layers. def build_layer(layer_number): if args.transformer_impl == 'local': return ParallelTransformerLayer( init_method, output_layer_init_method, layer_number, layer_type=layer_type, self_attn_mask_type=self_attn_mask_type, drop_path_rate=self.drop_path_rates[layer_number - 1]) else: return transformer_engine.pytorch.TransformerLayer( args.hidden_size, args.ffn_hidden_size, args.num_attention_heads, layernorm_epsilon=args.layernorm_epsilon, hidden_dropout=args.hidden_dropout, attention_dropout=args.attention_dropout, init_method=init_method, output_layer_init_method=output_layer_init_method, layer_number=layer_number, kv_channels=args.kv_channels, self_attn_mask_type=self_attn_mask_type.name, tp_group=mpu.get_tensor_model_parallel_group(), get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker, fuse_wgrad_accumulation=args.gradient_accumulation_fusion, apply_query_key_layer_scaling=args. apply_query_key_layer_scaling, attention_softmax_in_fp32=args.attention_softmax_in_fp32, seq_length=args.seq_length, micro_batch_size=args.micro_batch_size, sequence_parallel=args.sequence_parallel, params_dtype=args.params_dtype, apply_residual_connection_post_layernorm=args. apply_residual_connection_post_layernorm, output_layernorm=False, layer_type='encoder', drop_path_rate=self.drop_path_rates[layer_number - 1], set_parallel_mode=True, fuse_qkv_params=True) if args.virtual_pipeline_model_parallel_size is not None: assert args.num_layers % \ args.virtual_pipeline_model_parallel_size == 0 assert args.model_type != ModelType.encoder_and_decoder # Number of layers in each model chunk # is the number of layers in the stage, # divided by the number of model chunks in a stage. self.num_layers = \ self.num_layers // args.virtual_pipeline_model_parallel_size # With 8 layers, 2 stages, and 4 model chunks, # we want an assignment of # layers to stages like (each list is a model chunk): # Stage 0: [0] [2] [4] [6] # Stage 1: [1] [3] [5] [7] # With 8 layers, 2 stages, and 2 virtual stages, # we want an assignment of # layers to stages like (each list is a model chunk): # Stage 0: [0, 1] [4, 5] # Stage 1: [2, 3] [6, 7] offset =\ mpu.get_virtual_pipeline_model_parallel_rank() * ( args.num_layers // args.virtual_pipeline_model_parallel_size) + ( mpu.get_pipeline_model_parallel_rank() * self.num_layers) else: # Each stage gets a contiguous set of layers. if args.model_type == ModelType.encoder_and_decoder and \ mpu.get_pipeline_model_parallel_world_size() > 1: pipeline_rank = mpu.get_pipeline_model_parallel_rank() if layer_type == LayerType.encoder: offset = pipeline_rank * self.num_layers else: num_ranks_in_enc = args.pipeline_model_parallel_split_rank offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers else: offset = mpu.get_pipeline_model_parallel_rank( ) * self.num_layers self.layers = torch.nn.ModuleList( [build_layer(i + 1 + offset) for i in range(self.num_layers)]) if self.post_process and self.post_layer_norm: # Final layer norm before output. self.final_layernorm = LayerNorm( args.hidden_size, eps=args.layernorm_epsilon, no_persist_layer_norm=args.no_persist_layer_norm, sequence_parallel=args.sequence_parallel, apply_layernorm_1p=args.apply_layernorm_1p) def _get_layer(self, layer_number): return self.layers[layer_number] def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb, is_first_microbatch): """Forward method with activation checkpointing.""" def custom(start, end, is_transformer_engine=False): def custom_forward(*inputs): x_ = inputs[0] attention_mask = inputs[1] encoder_output = inputs[2] enc_dec_attn_mask = inputs[3] rotary_pos_emb = inputs[4] for index in range(start, end): layer = self._get_layer(index) x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb=rotary_pos_emb) return x_ def custom_forward_transformer_engine(*args, **kwargs): return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs) if not is_transformer_engine: return custom_forward else: return custom_forward_transformer_engine if self.recompute_method == 'uniform': # Uniformly divide the total # number of Transformer layers and checkpoint # the input activation of each divided chunk. # A method to further reduce memory usage reducing checkpoints. layer = 0 while layer < self.num_layers: if self.transformer_impl == 'transformer_engine': tec = transformer_engine.pytorch.distributed.checkpoint hidden_states = tec( custom(layer, layer + self.recompute_num_layers, is_transformer_engine=True), self.distribute_saved_activations, tensor_parallel.get_cuda_rng_tracker, mpu.get_tensor_model_parallel_group(), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb) else: hidden_states = tensor_parallel.checkpoint( custom(layer, layer + self.recompute_num_layers), self.distribute_saved_activations, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb) layer += self.recompute_num_layers elif self.recompute_method == 'block': # Checkpoint the input activation # of only a set number of individual # Transformer layers and skip the rest. # A method fully use the device memory # removing redundant re-computation. for layer in range(self.num_layers): if layer < self.recompute_num_layers: if self.transformer_impl == 'transformer_engine': tec = transformer_engine.pytorch.distributed.checkpoint hidden_states = tec( custom(layer, layer + 1, is_transformer_engine=True), self.distribute_saved_activations, tensor_parallel.get_cuda_rng_tracker, mpu.get_tensor_model_parallel_group(), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb=rotary_pos_emb) else: hidden_states = tensor_parallel.checkpoint( custom(layer, layer + 1), self.distribute_saved_activations, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb=rotary_pos_emb) else: if self.transformer_impl == 'transformer_engine': hidden_states = custom( layer, layer + 1, is_transformer_engine=True)( hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb=rotary_pos_emb) else: hidden_states = custom(layer, layer + 1)( hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb=rotary_pos_emb) else: raise ValueError('Invalid activation recompute method.') return hidden_states def set_input_tensor(self, input_tensor): """Set input tensor to be used instead of forward()'s input. When doing pipeline parallelism the input from the previous stage comes from communication, not from the input, so the model's forward_step_func won't have it. This function is thus used by internal code to bypass the input provided by the forward_step_func""" self.input_tensor = input_tensor def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None, rotary_pos_emb=None): # hidden_states: [s, b, h] # Checks. if inference_params: assert self.recompute_granularity is None, \ 'inference does not work with activation checkpointing' if not self.pre_process: # See set_input_tensor() hidden_states = self.input_tensor # Viewless tensor. # - We only need to create a viewless tensor in the case of micro batch # size (mbs) == 1, since in this case, 'hidden_states.transpose()' # above creates a view tensor, and '.contiguous()' is a pass-through. # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating # the need to make it viewless. # # However, we don't explicitly check mbs == 1 here because # make_viewless_tensor() has negligible overhead when its input # is already viewless. # # - For the 'else' case above, calling make_viewless_tensor() here is # likely redundant, since p2p_communication.py (likely originator) # already creates viewless tensors. That said, make_viewless_tensor() # is called here to be future-proof and corner-case-proof. hidden_states = core.utils.make_viewless_tensor( hidden_states, requires_grad=True, keep_graph=True, ) if self.sequence_parallel: rng_context = tensor_parallel.get_cuda_rng_tracker().fork() else: rng_context = nullcontext() with rng_context: # The fp8_autocast context manager is a no-op when enabled=True # The if...else serves to short # circuit name resolution for fp8_autocast with transformer_engine.pytorch.fp8_autocast( enabled=self.use_fp8, fp8_recipe=self.fp8_recipe, fp8_group=self.fp8_group) if self.use_fp8 else nullcontext( ): # Determine if the current iteration is first microbatch if self.num_microbatches_in_previous_step !=\ get_num_microbatches(): self.microbatch_count = 0 self.num_microbatches_in_previous_step = get_num_microbatches() is_first_microbatch =\ self.microbatch_count % get_num_microbatches() == 0 # Forward pass. if self.recompute_granularity == 'full': hidden_states = self._checkpointed_forward( hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb, is_first_microbatch) else: forward_kwargs = { 'encoder_output': encoder_output, 'enc_dec_attn_mask': enc_dec_attn_mask, 'inference_params': inference_params, 'rotary_pos_emb': rotary_pos_emb, } if self.transformer_impl == 'transformer_engine': forward_kwargs['is_first_microbatch'] =\ is_first_microbatch forward_kwargs['checkpoint_core_attention'] =\ self.checkpoint_core_attention for index in range(self.num_layers): layer = self._get_layer(index) hidden_states = layer(hidden_states, attention_mask, **forward_kwargs) # Skip counter update for eval and activation checkpointing if torch.is_grad_enabled() and self.training: self.microbatch_count += 1 # Final layer norm. if self.post_process and self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) return hidden_states