# megatron_patch/model/qwen2/moe/experts.py
# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
)
from megatron.core.jit import jit_fuser
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from ..transformer.mlp import MLP, MLPSubmodules, apply_swiglu_sharded_factory


class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
"""
def __init__(self, num_local_experts: int, config: TransformerConfig):
super().__init__(config=config)
self.config: TransformerConfig = config
self.num_local_experts = num_local_experts
gg.assert_grouped_gemm_is_available()
assert (
config.add_bias_linear == False
), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
self.expert_parallel = config.expert_model_parallel_size > 1
if self.config.gated_linear_unit:
if self.config.activation_func not in (F.silu, F.gelu):
raise ValueError("Activation function must be silu or gelu when using GroupedMLP.")
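# JIT-fused gated linear unit: split the fc1 output in half, apply the
# activation to the first half, and gate it with the second half.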
@jit_fuser
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
# How many features each rank holds for fc1 and fc2, respectively.
if config.moe_extended_tp:
tp_size = parallel_state.get_tensor_and_expert_parallel_world_size()
else:
tp_size = parallel_state.get_tensor_model_parallel_world_size()
fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
if config.gated_linear_unit:
# Project to 4h. If using swiglu, double the output width;
# see https://arxiv.org/pdf/2002.05202.pdf
fc1_output_size *= 2
fc1_output_size_per_partition = divide(fc1_output_size, tp_size)
fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
fc2_input_size_per_partition = divide(fc2_input_size, tp_size)
# Note: The current kernel implementation of grouped_gemm
# does not support transposition with CUTLASS grouped GEMM
# (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358),
# so we avoid allocating the transpose of the weights.
# Initialize weight.
if config.use_cpu_initialization:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight1,
self.config.hidden_size,
fc1_output_size,
fc1_output_size_per_partition,
partition_dim=1,
init_method=config.init_method,
params_dtype=config.params_dtype,
)
_initialize_affine_weight_cpu(
self.weight2,
fc2_input_size,
self.config.hidden_size,
fc2_input_size_per_partition,
partition_dim=0,
init_method=config.output_layer_init_method,
params_dtype=config.params_dtype,
)
else:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight1,
config.init_method,
partition_dim=1,
expert_parallel=self.expert_parallel,
)
_initialize_affine_weight_gpu(
self.weight2,
config.output_layer_init_method,
partition_dim=0,
expert_parallel=self.expert_parallel,
)
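# Mark expert-parallel weights so the regular data-parallel gradient allreduce skips them;
# their gradients are reduced over the data-modulo-expert-parallel group instead.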
setattr(self.weight1, 'allreduce', not self.expert_parallel)
setattr(self.weight2, 'allreduce', not self.expert_parallel)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
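# Grouped GEMM runs one GEMM per expert, batched by the per-expert token counts.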
fc1_output = gg.ops.gmm(
permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
else:
# No tokens are allocated to the local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
# Make sure parameters still have gradients when no tokens are routed to this set of experts.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
h = torch.matmul(permuted_local_hidden_states, w1)
h = self.activation_func(h)
h = torch.matmul(h, w2)
fc2_output = h
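# Bias is not supported in the grouped GEMM path, so the bias output is always None.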
return fc2_output, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
raise NotImplementedError(
'Currently distributed checkpointing is not supported for GroupedMLP'
)


class TEGroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using TE's GroupedLinear.
This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
super().__init__(config=config)
self.moe_extended_tp = config.moe_extended_tp
self.num_local_experts = num_local_experts
self.input_size = self.config.hidden_size
# If this is a gated linear unit, we double the output width; see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.num_local_experts,
self.input_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=True,
tp_comm_buffer_name='fc1',
)
self.activation_func = self.config.activation_func
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.num_local_experts,
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=True,
tp_comm_buffer_name='fc2',
)
def remove_extra_states_check(self, incompatible_keys):
"""
Remove extra _extra_state from unexpected keys.
These keys are for dist ckpt compatibility with SequentialMLP.
"""
keys = deepcopy(incompatible_keys.unexpected_keys)
for key in keys:
if '_extra_state' in key:
incompatible_keys.unexpected_keys.remove(key)
self.register_load_state_dict_post_hook(remove_extra_states_check)
def forward(
self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward of TEGroupedMLP
Args:
permuted_local_hidden_states (torch.Tensor): The permuted input hidden states of the
local experts.
tokens_per_expert (torch.Tensor): The number of tokens per expert.
Returns:
output (torch.Tensor): The output of the local experts.
output_bias (Optional[torch.Tensor]): The unfused bias of the local experts, if any.
"""
tokens_per_expert = tokens_per_expert.tolist()
intermediate_parallel, bias_parallel = self.linear_fc1(
permuted_local_hidden_states, tokens_per_expert
)
if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
if self.config.gated_linear_unit:
intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
else:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(
intermediate_parallel,
bias_parallel,
self.config.activation_func_fp8_input_store,
)
else:
raise ValueError("Only support fusion of gelu and swiglu")
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
intermediate_parallel = glu(intermediate_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel)
output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert)
return output, output_bias
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Maps local experts to global experts.
The sharded state dict is interchangeable with SequentialMLP's.
"""
if self.moe_extended_tp:
raise NotImplementedError(
'Currently distributed checkpointing is not supported for moe_extended_tp'
)
sharded_state_dict = {}
for name, module in self._modules.items():
sub_sd = module.sharded_state_dict(f'{name}.', sharded_offsets, metadata)
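# For gated linear units, each expert's fc1 weight stacks the gate and up projections;
# shard the two halves separately so the result matches SequentialMLP's swiglu layout.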
if name == 'linear_fc1' and self.config.gated_linear_unit:
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
ep_axis = len(sharded_offsets)
for i in range(self.num_local_experts):
new_sharded_offsets = (
*sharded_offsets,
(ep_axis, local_expert_indices_offset + i, num_global_experts),
)
for k in (f'{name}.weight{i}', f'{name}.bias{i}'):
if k in sub_sd:
sub_sd[k] = apply_swiglu_sharded_factory(sub_sd[k], new_sharded_offsets)
# Add the prefix here to match SequentialMLP's keys
replace_prefix_for_sharding(sub_sd, f'{name}.', f'{prefix}experts.{name}.')
sharded_state_dict.update({f"{prefix}{k}": v for k, v in sub_sd.items()})
return sharded_state_dict


class SequentialMLP(MegatronModule):
"""An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
super().__init__(config=config)
self.add_bias = config.add_bias_linear
self.moe_extended_tp = config.moe_extended_tp
self.num_local_experts = num_local_experts
self.local_experts = torch.nn.ModuleList()
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)
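# Tokens arrive grouped by expert, so a cumulative sum of tokens_per_expert gives each expert's slice boundaries.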
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert a zero at the beginning for convenient offset indexing
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden)
output_local[start:end] = output
if self.add_bias:
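# The expert returns its bias unfused as a single vector; broadcast it across this expert's tokens.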
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias
return output_local, output_bias_local
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Maps local expert to global experts. """
if self.moe_extended_tp:
raise NotImplementedError(
'Currently distributed checkpointing is not supported for moe_extended_tp'
)
sharded_state_dict = {}
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
expert_sharded_prefix = f'{prefix}experts.'
for expert_local_idx, expert in enumerate(self.local_experts):
expert_global_idx = local_expert_indices_offset + expert_local_idx
expert_state_dict_prefix = f'{prefix}local_experts.{expert_local_idx}.'
expert_sharded_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)
expert_state_dict = expert.sharded_state_dict(
expert_state_dict_prefix, expert_sharded_offsets, metadata
)
# Remove the expert layer index from the sharded keys
replace_prefix_for_sharding(
expert_state_dict, expert_state_dict_prefix, expert_sharded_prefix
)
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in expert_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (
*replica_id[:2],
parallel_state.get_data_modulo_expert_parallel_rank(),
)
sharded_state_dict.update(expert_state_dict)
return sharded_state_dict