megatron_patch/model/mixtral_bak/moe/experts.py (136 lines of code) (raw):
# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig
from ..transformer.mlp import MLP, MLPSubmodules
class GroupedMLP(MegatronModule):
    """An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.

    This class is designed to execute multiple experts in parallel, thereby
    maximizing computational efficiency. The fc1 and fc2 weights of all local
    experts are stored in two fused parameters (``weight1`` and ``weight2``),
    which are viewed as one matrix per expert in :meth:`forward`.
    """

    def __init__(self, num_local_experts: int, config: TransformerConfig):
        """Allocate (and optionally initialize) the fused expert weights.

        Args:
            num_local_experts (int): Number of experts hosted on this rank.
            config (TransformerConfig): Transformer configuration. ``add_bias_linear``
                must be disabled because grouped GEMM does not support biases.
        """
        super().__init__(config=config)
        self.config: TransformerConfig = config
        self.num_local_experts = num_local_experts
        gg.assert_grouped_gemm_is_available()
        assert (
            not config.add_bias_linear
        ), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."

        self.expert_parallel = config.expert_model_parallel_size > 1
        if self.config.gated_linear_unit:

            def glu(x):
                # Split fc1's output into two halves along the last dim: the
                # activation is applied to the first half and multiplied
                # element-wise with the second half.
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        # How many features each rank holds for fc1 and fc2, respectively.
        tp_size = parallel_state.get_tensor_model_parallel_world_size()

        fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
        if config.gated_linear_unit:
            # Project to 4h. If using swiglu double the output width,
            # see https://arxiv.org/pdf/2002.05202.pdf
            fc1_output_size *= 2
        fc1_output_size_per_partition = divide(fc1_output_size, tp_size)

        fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
        fc2_input_size_per_partition = divide(fc2_input_size, tp_size)

        # Note: the current kernel implementations of grouped_gemm
        # do not support transposition with CUTLASS grouped GEMM
        # (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358),
        # so we avoid allocating the transposes of the weights.
        # Initialize weight.
        if config.use_cpu_initialization:
            self.weight1 = Parameter(
                torch.empty(
                    self.config.hidden_size,
                    fc1_output_size_per_partition,
                    dtype=config.params_dtype,
                )
            )
            self.weight2 = Parameter(
                torch.empty(
                    fc2_input_size_per_partition,
                    self.config.hidden_size,
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight1,
                    self.config.hidden_size,
                    fc1_output_size,
                    fc1_output_size_per_partition,
                    partition_dim=1,
                    init_method=config.init_method,
                    params_dtype=config.params_dtype,
                )
                _initialize_affine_weight_cpu(
                    self.weight2,
                    fc2_input_size,
                    self.config.hidden_size,
                    fc2_input_size_per_partition,
                    partition_dim=0,
                    init_method=config.output_layer_init_method,
                    params_dtype=config.params_dtype,
                )
        else:
            self.weight1 = Parameter(
                torch.empty(
                    self.config.hidden_size,
                    fc1_output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            self.weight2 = Parameter(
                torch.empty(
                    fc2_input_size_per_partition,
                    self.config.hidden_size,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(
                    self.weight1,
                    config.init_method,
                    partition_dim=1,
                    expert_parallel=self.expert_parallel,
                )
                _initialize_affine_weight_gpu(
                    self.weight2,
                    config.output_layer_init_method,
                    partition_dim=0,
                    expert_parallel=self.expert_parallel,
                )
        # Expert weights are all-reduced only when experts are not sharded
        # across an expert-parallel group.
        setattr(self.weight1, 'allreduce', not self.expert_parallel)
        setattr(self.weight2, 'allreduce', not self.expert_parallel)

    def forward(self, permuted_local_hidden_states, tokens_per_expert):
        """Forward pass for the GroupedMLP module.

        Args:
            permuted_local_hidden_states (torch.Tensor): The input hidden states, permuted so
                that tokens routed to the same local expert are contiguous.
            tokens_per_expert: Number of tokens assigned to each local expert. It is passed
                to grouped GEMM as the per-expert batch sizes.

        Returns:
            Tuple[torch.Tensor, None]: The output of the local experts, and ``None`` as a
            placeholder for the bias output (grouped GEMM has no bias). The ``None`` keeps
            the return shape compatible with other expert modules.
        """
        if permuted_local_hidden_states.nelement() != 0:
            # Reshape the fused weights into one matrix per local expert.
            w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
            w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)

            fc1_output = gg.ops.gmm(
                permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
            )
            intermediate_parallel = self.activation_func(fc1_output)
            fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
        else:
            # No token was routed to the local experts. Skip grouped GEMM and
            # run an (empty) dense matmul instead, so the weights still take
            # part in the autograd graph and receive (zero) gradients.
            w1 = self.weight1.view(self.config.hidden_size, -1)
            w2 = self.weight2.view(-1, self.config.hidden_size)
            h = torch.matmul(permuted_local_hidden_states, w1)
            h = self.activation_func(h)
            fc2_output = torch.matmul(h, w2)

        return fc2_output, None
class SequentialMLP(MegatronModule):
    """An implementation of the Experts layer using a sequence of MLP layers.

    This class executes each expert sequentially.
    """

    def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
        """Build ``num_local_experts`` independent expert MLPs.

        Args:
            num_local_experts (int): Number of experts hosted on this rank.
            config (TransformerConfig): Transformer configuration shared by all experts.
            submodules (MLPSubmodules): Layer specs used to build each expert MLP.
        """
        super().__init__(config=config)
        self.add_bias = config.add_bias_linear
        self.num_local_experts = num_local_experts
        self.local_experts = torch.nn.ModuleList()
        for _ in range(self.num_local_experts):
            expert = MLP(self.config, submodules, is_expert=True)
            self.local_experts.append(expert)

    def forward(self, permuted_local_hidden_states, tokens_per_expert):
        """Forward pass for the SequentialMLP module.

        It processes the input hidden states using a sequence of MLP experts. Each
        expert operates on a contiguous slice of the input corresponding to the
        tokens it is responsible for.

        Args:
            permuted_local_hidden_states (torch.Tensor): Tensor containing hidden states
                that have been permuted so that tokens processed by the same expert are contiguous.
            tokens_per_expert (torch.Tensor): Tensor indicating the number of tokens that
                each expert is responsible for processing.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: The output of the experts after
            processing the hidden states, and the expert output bias if ``add_bias`` is
            True (otherwise ``None``).
        """
        output_local = torch.zeros_like(permuted_local_hidden_states)
        output_bias_local = None
        if self.add_bias:
            output_bias_local = torch.zeros_like(permuted_local_hidden_states)

        # Slice offsets for each expert, with a leading 0 so expert i covers
        # [offsets[i], offsets[i + 1]). Building them as Python ints avoids
        # concatenating a CPU zero tensor with a cumsum that may live on another
        # device, and avoids indexing with per-expert 0-d tensors.
        cumsum_num_tokens = [0] + torch.cumsum(tokens_per_expert, dim=0).tolist()

        for expert_num, expert in enumerate(self.local_experts):
            start = cumsum_num_tokens[expert_num]
            end = cumsum_num_tokens[expert_num + 1]
            hidden = permuted_local_hidden_states[start:end]
            output, output_bias = expert(hidden)

            output_local[start:end] = output
            if self.add_bias and output_bias is not None:
                # Broadcast the per-feature bias to every token of this expert.
                output_bias_local[start:end, :] = output_bias.expand_as(output)

        return output_local, output_bias_local