megatron_patch/model/qwen1_5/transformer/mlp.py (164 lines of code) (raw):
# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
@dataclass
class MLPSubmodules:
    """Build specs for the two linear layers of ``MLP``.

    Each field is passed as the first argument of ``build_module`` when
    ``MLP.__init__`` constructs the corresponding layer.
    """

    # Spec or class used to build the input projection (``MLP.linear_fc1``).
    linear_fc1: Optional[Union[ModuleSpec, type]] = None
    # Spec or class used to build the output projection (``MLP.linear_fc2``).
    linear_fc2: Optional[Union[ModuleSpec, type]] = None
class MLP(MegatronModule):
    """
    MLP will take the input with h hidden state, project it to the
    intermediate (ffn) hidden dimension, perform nonlinear transformation,
    and project the state back into h hidden dimension.

    Returns an output and a bias to be added to the output.
    If config.add_bias_linear is False, the bias returned is None.

    When config.moe_ffn_hidden_size is set, the intermediate width is taken
    from config.moe_ffn_hidden_size (routed expert) or
    config.shared_moe_ffn_hidden_size (shared expert) instead of
    config.ffn_hidden_size.

    We use the following notation:
     h: hidden size
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MLPSubmodules,
        is_expert: bool = False,
        input_size: Optional[int] = None,
        is_shared_expert: bool = False,
    ):
        """Build the ``linear_fc1 -> activation -> linear_fc2`` stack.

        Args:
            config: transformer config; supplies sizes, init methods, bias and
                activation settings.
            submodules: specs/types used to build ``linear_fc1`` and
                ``linear_fc2``.
            is_expert: forwarded to both linear layers' builders.
            input_size: input width of ``linear_fc1``; defaults to
                ``config.hidden_size`` when None.
            is_shared_expert: only consulted when ``config.moe_ffn_hidden_size``
                is set; selects ``config.shared_moe_ffn_hidden_size`` instead of
                ``config.moe_ffn_hidden_size`` as the intermediate width.
        """
        super().__init__(config=config)

        self.config: TransformerConfig = config

        self.input_size = input_size if input_size is not None else self.config.hidden_size

        ffn_hidden_size = self._intermediate_size(self.config, is_shared_expert)

        # If this is a gated linear unit we double the output width of fc1
        # (gate half + linear half), see https://arxiv.org/pdf/2002.05202.pdf
        if self.config.gated_linear_unit:
            fc1_output_size = ffn_hidden_size * 2
        else:
            fc1_output_size = ffn_hidden_size

        self.linear_fc1 = build_module(
            submodules.linear_fc1,
            self.input_size,
            fc1_output_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc1',
        )

        self.activation_func = self.config.activation_func

        # fc2 consumes the post-activation tensor, which is never doubled.
        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
        )

    @staticmethod
    def _intermediate_size(config: TransformerConfig, is_shared_expert: bool):
        """Return the (un-doubled) intermediate width selected by ``config``.

        Uses ``config.ffn_hidden_size`` for a dense MLP. When
        ``config.moe_ffn_hidden_size`` is set, it uses
        ``config.shared_moe_ffn_hidden_size`` for a shared expert, and
        ``config.moe_ffn_hidden_size`` otherwise.
        """
        if config.moe_ffn_hidden_size is None:
            return config.ffn_hidden_size
        if is_shared_expert:
            return config.shared_moe_ffn_hidden_size
        return config.moe_ffn_hidden_size

    def forward(self, hidden_states):
        """Apply fc1, the (optionally gated and/or fused) activation, then fc2.

        Args:
            hidden_states: input to ``linear_fc1``.

        Returns:
            ``(output, output_bias)`` exactly as returned by ``linear_fc2``.
            fc2 is built with ``skip_bias_add=True``, so its bias is returned
            separately rather than added.

        Raises:
            ValueError: if ``config.bias_activation_fusion`` is set but the
                activation is neither gelu nor silu combined with a gated
                linear unit.
        """
        # [s, b, 4 * h/p]
        intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

        if self.config.bias_activation_fusion:
            if self.activation_func == F.gelu:
                if self.config.gated_linear_unit:
                    intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
                else:
                    # The non-gated fused gelu kernel needs a real bias tensor.
                    assert self.config.add_bias_linear is True, (
                        'fused bias-gelu without a gated linear unit requires '
                        'config.add_bias_linear=True'
                    )
                    intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
            elif self.activation_func == F.silu and self.config.gated_linear_unit:
                intermediate_parallel = bias_swiglu_impl(intermediate_parallel, bias_parallel)
            else:
                raise ValueError("Only support fusion of gelu and swiglu")
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            if self.config.gated_linear_unit:
                # GLU: the activated first half gates the second half along
                # the last dimension.
                gate, linear = torch.chunk(intermediate_parallel, 2, dim=-1)
                intermediate_parallel = self.activation_func(gate) * linear
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.linear_fc2(intermediate_parallel)
        return output, output_bias

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """Collect the sharded state dicts of all child modules.

        Each child's entries are keyed under ``f'{prefix}{name}.'``. When
        ``config.gated_linear_unit`` is set, ``linear_fc1`` goes through
        ``_sharded_state_dict_for_glu`` so its weight is split into gate and
        linear halves.
        """
        sharded_state_dict = {}
        for name, module in self._modules.items():
            if name == 'linear_fc1' and self.config.gated_linear_unit:
                sub_sd = self._sharded_state_dict_for_glu(
                    name, module, prefix, sharded_offsets, metadata
                )
            else:
                sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            sharded_state_dict.update(sub_sd)
        return sharded_state_dict

    def _sharded_state_dict_for_glu(
        self,
        module_name: str,
        module: torch.nn.Module,
        prefix: str,
        sharded_offsets: Tuple[Tuple[int, int, int], ...],
        metadata: Optional[dict] = None,
    ):
        """Return ``linear_fc1``'s sharded state dict with its weight GLU-split.

        The local fc1 weight on each TP rank holds two halves stacked along
        dim 0. This assumes dim 0 indexes fc1 output features, matching the
        ``torch.chunk(..., 2)`` in ``forward``; confirm against the linear
        layer implementation. The weight entry is replaced by a
        ``ShardedTensorFactory``:

        * On save, it chunks the local tensor in two. The first half is stored
          at chunk ``tp_rank`` and the second half at chunk
          ``tp_size + tp_rank``, out of ``2 * tp_size`` chunks along axis 0.
        * On load, it concatenates the two halves back together.

        The global layout is therefore [all first halves; all second halves].

        NOTE(review): only ``weight`` is wrapped. If ``config.add_bias_linear``
        is True, the fc1 ``bias`` keeps its plain TP sharding, so its global
        layout depends on tp_size. Changing this would alter the on-disk layout
        of existing checkpoints, so it is left as-is. Confirm before resharding
        such models across TP sizes.

        Args:
            module_name: must be ``'linear_fc1'``.
            module: the fc1 module whose ``sharded_state_dict`` is wrapped.
            prefix: key prefix for this MLP.
            sharded_offsets: extra offsets prepended to every shard; their
                count shifts the TP shard axis.
            metadata: forwarded to ``module.sharded_state_dict``.
        """
        assert module_name == 'linear_fc1', module_name
        sharded_state_dict = module.sharded_state_dict(
            f'{prefix}{module_name}.', sharded_offsets, metadata
        )
        weight_key = f'{prefix}{module_name}.weight'
        prev_sh_ten = sharded_state_dict[weight_key]

        # We must split the tensor into 2 parts, each sharded separately.
        # This requires a ShardedTensorFactory which `chunk`s during saving
        # and `cat`s during loading
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        tp_size = parallel_state.get_tensor_model_parallel_world_size()

        tp_shard_axis = 0
        prepend_axis_num = len(sharded_offsets)

        def sh_ten_build_fn(key: str, t: torch.Tensor, replica_id: ReplicaId):
            # (axis, chunk index, number of chunks) for each half.
            offset_w = (tp_shard_axis + prepend_axis_num, tp_rank, tp_size * 2)
            offset_v = (tp_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2)
            with torch.no_grad():
                tensor_w, tensor_v = torch.chunk(t, 2, dim=tp_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_w,
                    *sharded_offsets,
                    offset_w,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_v,
                    *sharded_offsets,
                    offset_v,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
            ]

        def sh_ten_merge_fn(sub_state_dict):
            # Inverse of sh_ten_build_fn: re-stack the two loaded halves.
            with torch.no_grad():
                return torch.cat(sub_state_dict)

        sharded_state_dict[weight_key] = ShardedTensorFactory(
            prev_sh_ten.key,
            prev_sh_ten.data,
            sh_ten_build_fn,
            sh_ten_merge_fn,
            prev_sh_ten.replica_id,
        )
        return sharded_state_dict