megatron_patch/model/mixtral_bak/moe/experts.py (136 lines of code) (raw):

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn.parameter import Parameter

from megatron.core import parallel_state
from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    _initialize_affine_weight_gpu,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig

from ..transformer.mlp import MLP, MLPSubmodules


class GroupedMLP(MegatronModule):
    """An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.

    This class is designed to execute multiple experts in parallel, thereby maximizing
    computational efficiency.
    """

    def __init__(self, num_local_experts: int, config: TransformerConfig):
        """Allocate (and optionally initialize) the stacked fc1/fc2 expert weights.

        Args:
            num_local_experts: Number of experts held by this rank; their weights are
                stored side by side in ``weight1`` / ``weight2``.
            config: Transformer configuration (sizes, activation, init methods,
                parallelism settings).

        Raises:
            AssertionError: If grouped GEMM is unavailable or if
                ``config.add_bias_linear`` is set (bias is not supported here).
        """
        super().__init__(config=config)
        self.config: TransformerConfig = config
        self.num_local_experts = num_local_experts
        gg.assert_grouped_gemm_is_available()
        assert (
            not config.add_bias_linear
        ), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."

        self.expert_parallel = config.expert_model_parallel_size > 1
        if self.config.gated_linear_unit:

            def glu(x):
                # Split the doubled fc1 output in half along the last dim and gate the
                # second half with the activated first half.
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        # How many features each rank holds for fc1 and fc2, respectively.
        tp_size = parallel_state.get_tensor_model_parallel_world_size()

        fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
        if config.gated_linear_unit:
            # Project to 4h. If using swiglu double the output width,
            # see https://arxiv.org/pdf/2002.05202.pdf
            fc1_output_size *= 2
        fc1_output_size_per_partition = divide(fc1_output_size, tp_size)

        fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
        fc2_input_size_per_partition = divide(fc2_input_size, tp_size)

        # Note: The current kernel implementations of grouped_gemm
        # does not support transposition with CUTLASS grouped GEMM
        # (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358)
        # and as a result we avoid allocate the transpose of weights.
        # Initialize weight.
        if config.use_cpu_initialization:
            self.weight1 = Parameter(
                torch.empty(
                    self.config.hidden_size,
                    fc1_output_size_per_partition,
                    dtype=config.params_dtype,
                )
            )
            self.weight2 = Parameter(
                torch.empty(
                    fc2_input_size_per_partition,
                    self.config.hidden_size,
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight1,
                    self.config.hidden_size,
                    fc1_output_size,
                    fc1_output_size_per_partition,
                    partition_dim=1,
                    init_method=config.init_method,
                    params_dtype=config.params_dtype,
                )
                _initialize_affine_weight_cpu(
                    self.weight2,
                    fc2_input_size,
                    self.config.hidden_size,
                    fc2_input_size_per_partition,
                    partition_dim=0,
                    init_method=config.output_layer_init_method,
                    params_dtype=config.params_dtype,
                )
        else:
            self.weight1 = Parameter(
                torch.empty(
                    self.config.hidden_size,
                    fc1_output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            self.weight2 = Parameter(
                torch.empty(
                    fc2_input_size_per_partition,
                    self.config.hidden_size,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(
                    self.weight1,
                    config.init_method,
                    partition_dim=1,
                    expert_parallel=self.expert_parallel,
                )
                _initialize_affine_weight_gpu(
                    self.weight2,
                    config.output_layer_init_method,
                    partition_dim=0,
                    expert_parallel=self.expert_parallel,
                )
        # Flag read elsewhere for gradient all-reduce; disabled when experts are
        # sharded across expert-parallel ranks.
        setattr(self.weight1, 'allreduce', not self.expert_parallel)
        setattr(self.weight2, 'allreduce', not self.expert_parallel)

    def forward(self, permuted_local_hidden_states, tokens_per_expert):
        """Forward pass for the GroupedMLP module.

        Args:
            permuted_local_hidden_states (torch.Tensor): The input hidden states with
                dimensions suited for expert parallelism. It's typically a result of
                permuting the original hidden states to align tokens with their
                corresponding experts.
            tokens_per_expert (list of int): Number of tokens assigned to each expert.
                This is used to manage the distribution of tokens across the experts
                in the grouped GEMM operation.

        Returns:
            torch.Tensor: The output of the MLP after processing by the local experts.
            None: Placeholder for any additional output, for compatibility with other
                modules.
        """
        # Reshape the stacked weights into one slice per local expert for the
        # grouped GEMMs.
        w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
        w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)

        fc1_output = gg.ops.gmm(
            permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
        )
        intermediate_parallel = self.activation_func(fc1_output)
        fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)

        return fc2_output, None


class SequentialMLP(MegatronModule):
    """An implementation of the Experts layer using a sequence of MLP layers.

    This class executes each expert sequentially.
    """

    def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
        """Build ``num_local_experts`` independent expert MLPs.

        Args:
            num_local_experts: Number of experts held by this rank.
            config: Transformer configuration; ``add_bias_linear`` decides whether
                expert biases are collected in ``forward``.
            submodules: Layer specs passed to each expert ``MLP``.
        """
        super().__init__(config=config)
        self.add_bias = config.add_bias_linear
        self.num_local_experts = num_local_experts
        self.local_experts = torch.nn.ModuleList()
        for _ in range(self.num_local_experts):
            expert = MLP(self.config, submodules, is_expert=True)
            self.local_experts.append(expert)

    def forward(self, permuted_local_hidden_states, tokens_per_expert):
        """Forward pass for the SequentialMLP module.

        It processes the input hidden states using a sequence of MLP experts. Each
        expert operates on a contiguous slice of the input corresponding to the tokens
        it is responsible for.

        Args:
            permuted_local_hidden_states (torch.Tensor): Tensor containing hidden states
                that have been permuted so that tokens processed by the same expert are
                contiguous.
            tokens_per_expert (torch.Tensor): Tensor indicating the number of tokens
                that each expert is responsible for processing.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: The first tensor is the output
            from the experts after processing the hidden states. The second tensor is
            the output bias from the experts if `add_bias` is True; otherwise, it is
            None.
        """
        output_local = torch.zeros_like(permuted_local_hidden_states)
        output_bias_local = None
        if self.add_bias:
            output_bias_local = torch.zeros_like(permuted_local_hidden_states)

        cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
        # Insert zero at the beginning for offset index's convenience. The zero is
        # created on the cumsum's own device/dtype: torch.cat cannot mix CPU and CUDA
        # tensors, and tokens_per_expert may live on the GPU.
        zero_tensor = torch.zeros(
            1, dtype=cumsum_num_tokens.dtype, device=cumsum_num_tokens.device
        )
        cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
        for expert_num, expert in enumerate(self.local_experts):
            start = cumsum_num_tokens[expert_num]
            end = cumsum_num_tokens[expert_num + 1]
            hidden = permuted_local_hidden_states[start:end]
            output, output_bias = expert(hidden)

            output_local[start:end] = output
            # `self.add_bias` is the only bias flag defined on this module.
            if self.add_bias:
                # Broadcast the expert's bias over every token in its slice.
                output_bias = output_bias.expand_as(output)
                output_bias_local[start:end, :] = output_bias

        return output_local, output_bias_local