megatron_patch/model/qwen2/transformer/mlp.py

# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
    ReplicaId,
    ShardedStateDict,
    ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class MLPSubmodules:
    linear_fc1: Union[ModuleSpec, type] = None
    linear_fc2: Union[ModuleSpec, type] = None


class MLP(MegatronModule):
    """
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.

    Returns an output and a bias to be added to the output.
    If config.add_bias_linear is False, the bias returned is None.

    We use the following notation:
     h: hidden size
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MLPSubmodules,
        is_expert: bool = False,
        input_size: int = None,
        is_shared_expert: bool = False,
    ):
        super().__init__(config=config)

        self.config: TransformerConfig = config

        self.input_size = input_size if input_size is not None else self.config.hidden_size

        # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        if self.config.moe_ffn_hidden_size is not None:
            if not is_shared_expert:
                ffn_hidden_size = self.config.moe_ffn_hidden_size
            else:
                ffn_hidden_size = self.config.shared_moe_ffn_hidden_size
        else:
            ffn_hidden_size = self.config.ffn_hidden_size

        if self.config.gated_linear_unit:
            ffn_hidden_size *= 2

        self.linear_fc1 = build_module(
            submodules.linear_fc1,
            self.input_size,
            ffn_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc1',
        )

        self.activation_func = self.config.activation_func

        # Re-derive the un-doubled FFN width for linear_fc2: when gated_linear_unit is set,
        # the gated activation halves linear_fc1's doubled output, so linear_fc2 still
        # consumes ffn_hidden_size (or its MoE / shared-expert variant) features.
        if self.config.moe_ffn_hidden_size is not None:
            if not is_shared_expert:
                ffn_hidden_size = self.config.moe_ffn_hidden_size
            else:
                ffn_hidden_size = self.config.shared_moe_ffn_hidden_size
        else:
            ffn_hidden_size = self.config.ffn_hidden_size

        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
        )

    def forward(self, hidden_states):
        # [s, b, 4 * h/p]
        intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

        if self.config.bias_activation_fusion:
            if self.activation_func == F.gelu:
                if self.config.gated_linear_unit:
                    intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
                else:
                    assert self.config.add_bias_linear is True
                    intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
            elif self.activation_func == F.silu and self.config.gated_linear_unit:
                intermediate_parallel = bias_swiglu_impl(
                    intermediate_parallel,
                    bias_parallel,
                    self.config.activation_func_fp8_input_store,
                )
            else:
                raise ValueError("Only support fusion of gelu and swiglu")
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            if self.config.gated_linear_unit:

                def glu(x):
                    x = torch.chunk(x, 2, dim=-1)
                    return self.config.activation_func(x[0]) * x[1]

                intermediate_parallel = glu(intermediate_parallel)
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.linear_fc2(intermediate_parallel)

        return output, output_bias

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        sharded_state_dict = {}
        for name, module in self._modules.items():
            if name == 'linear_fc1' and self.config.gated_linear_unit:
                sub_sd = self._sharded_state_dict_for_glu(
                    name, module, prefix, sharded_offsets, metadata
                )
            else:
                sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            sharded_state_dict.update(sub_sd)
        return sharded_state_dict

    def _sharded_state_dict_for_glu(
        self,
        module_name: str,
        module: torch.nn.Module,
        prefix: str,
        sharded_offsets: Tuple[Tuple[int, int, int]],
        metadata: Optional[dict] = None,
    ):
        assert module_name == 'linear_fc1', module_name
        sharded_state_dict = module.sharded_state_dict(
            f'{prefix}{module_name}.', sharded_offsets, metadata
        )
        weight_key = f'{prefix}{module_name}.weight'
        prev_sh_ten = sharded_state_dict[weight_key]

        # We must split the tensor into 2 parts, each sharded separately.
        # This requires a ShardedTensorFactory which `chunk`s during saving
        # and `cat`s during loading
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        tp_size = parallel_state.get_tensor_model_parallel_world_size()

        tp_shard_axis = 0
        prepend_axis_num = len(sharded_offsets)

        def sh_ten_build_fn(key: str, t: torch.Tensor, replica_id: ReplicaId):
            offset_w = (tp_shard_axis + prepend_axis_num, tp_rank, tp_size * 2)
            offset_v = (tp_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2)
            with torch.no_grad():
                tensor_w, tensor_v = torch.chunk(t, 2, dim=tp_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_w,
                    *sharded_offsets,
                    offset_w,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_v,
                    *sharded_offsets,
                    offset_v,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
            ]

        def sh_ten_merge_fn(sub_state_dict):
            with torch.no_grad():
                return torch.cat(sub_state_dict)

        sharded_state_dict[weight_key] = ShardedTensorFactory(
            prev_sh_ten.key,
            prev_sh_ten.data,
            sh_ten_build_fn,
            sh_ten_merge_fn,
            prev_sh_ten.replica_id,
        )
        return sharded_state_dict


def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
    # We must split the tensor into 2 parts, each sharded separately.
    # This requires a ShardedTensorFactory which `chunk`s during saving
    # and `cat`s during loading
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    swiglu_shard_axis = 0
    prepend_axis_num = len(sharded_offsets)
    original_shape = original_sh_ten.local_shape
    original_numel = int(np.prod(original_shape))

    @torch.no_grad()
    def sh_ten_build_fn(
        key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
    ):
        offset_w = (swiglu_shard_axis + prepend_axis_num, tp_rank, tp_size * 2)
        offset_v = (swiglu_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2)
        if flattened_range is None:
            tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_w,
                    *sharded_offsets,
                    offset_w,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_v,
                    *sharded_offsets,
                    offset_v,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
            ]
        else:
            # Here we need to map a slice `t` (`flattened_range` specifies slice start and stop)
            # of the *original* flattened tensor into slices `w` and `v` of chunked
            # and flattened tensor.
            # Example:
            # If original tensor has (16, 5) shape and flattened_range is `slice(8, 64)`,
            # then `t` has shape `(56,)` and we need to create 2 tensors:
            # w: first 32 elements of `t` with flattened_range slice(8, 40)
            # v: last 24 elements of `t` with flattened_range slice(0, 24)
            # Global offsets are the same as in the non-flattened case
            assert t.ndim == 1, (key, t.shape)
            non_flat_local_shape = (original_shape[0] // 2, *original_shape[1:])
            chunk_numel = original_numel // 2
            result = []
            if flattened_range.start < chunk_numel:
                # Non-empty `w` chunk
                tensor_w = t[: chunk_numel - flattened_range.start]
                flattened_range_w = slice(
                    flattened_range.start, min(chunk_numel, flattened_range.stop)
                )
                assert len(tensor_w) == flattened_range_w.stop - flattened_range_w.start
                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_w,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_w,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_w,
                    )
                )
            if flattened_range.stop > chunk_numel:
                # Non-empty `v` chunk
                tensor_v = t[-(flattened_range.stop - chunk_numel) :]
                flattened_range_v = slice(
                    max(chunk_numel, flattened_range.start) - chunk_numel,
                    flattened_range.stop - chunk_numel,
                )
                assert len(tensor_v) == flattened_range_v.stop - flattened_range_v.start, (
                    len(tensor_v),
                    flattened_range_v,
                )
                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_v,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_v,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_v,
                    )
                )
            assert sum(sh_ten.data.numel() for sh_ten in result) == t.numel(), (result, t.shape)
            return result

    def sh_ten_merge_fn(sub_state_dict):
        with torch.no_grad():
            return torch.cat(sub_state_dict)

    return ShardedTensorFactory(
        original_sh_ten.key,
        original_sh_ten.data,
        sh_ten_build_fn,
        sh_ten_merge_fn,
        original_sh_ten.replica_id,
    )
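

# ---------------------------------------------------------------------------
# Illustrative sketch (an editorial addition, not part of the upstream module):
# it re-derives, with plain torch tensors and the hypothetical (16, 5) shape
# used in the comments of `sh_ten_build_fn` above, how a flattened slice of the
# fc1 weight is mapped onto the `w` and `v` GLU halves. Run this file directly
# to check the arithmetic; importing the module is unaffected.
if __name__ == "__main__":
    original = torch.arange(80).reshape(16, 5)  # stands in for a local fc1 weight shard
    flattened_range = slice(8, 64)              # slice of the flattened original tensor
    t = original.flatten()[flattened_range]     # shape (56,), as in the comment above
    chunk_numel = original.numel() // 2         # 40 elements per GLU half

    # `w` half: overlap of flattened_range with [0, chunk_numel)
    tensor_w = t[: chunk_numel - flattened_range.start]
    flattened_range_w = slice(flattened_range.start, min(chunk_numel, flattened_range.stop))
    assert tensor_w.numel() == flattened_range_w.stop - flattened_range_w.start == 32

    # `v` half: overlap of flattened_range with [chunk_numel, 2 * chunk_numel)
    tensor_v = t[-(flattened_range.stop - chunk_numel):]
    flattened_range_v = slice(
        max(chunk_numel, flattened_range.start) - chunk_numel,
        flattened_range.stop - chunk_numel,
    )
    assert tensor_v.numel() == flattened_range_v.stop - flattened_range_v.start == 24

    # Together the two pieces cover the requested slice exactly, mirroring the
    # final assertion in `sh_ten_build_fn`.
    assert tensor_w.numel() + tensor_v.numel() == t.numel()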