megatron_patch/model/qwen2/moe/token_dispatcher.py (327 lines of code):

# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List, Optional, Tuple

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.transformer.moe.moe_utils import moe_gather, moe_scatter, permute, unpermute
from megatron.core.transformer.transformer_config import TransformerConfig


class MoETokenDispatcher:
    """
    MoE Token Dispatcher
    """

    def __init__(self, config: TransformerConfig) -> None:
        """
        Initialize the MoE Token Dispatcher.
        """
        self.config = config

    @abstractmethod
    def token_permutation(
        self, tokens: torch.Tensor, indices: torch.Tensor,
    ):
        """Dispatch tokens to experts.

        Args:
            tokens (torch.Tensor): Input tokens.
            indices (torch.Tensor): Indices tensor.

        Returns:
            torch.Tensor: Tokens tensor.
        """
        raise NotImplementedError("Dispatch function not implemented.")

    @abstractmethod
    def token_unpermutation(
        self, expert_output: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
    ):
        """Restores the expert output to its original ordering.

        Args:
            expert_output (torch.Tensor): The output tensor from the expert models.
            probs (torch.Tensor): Each token's score with each expert.
            indices (torch.Tensor): The indices used to reorder the expert output.

        Returns:
            (torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
        """
        raise NotImplementedError("Restore function not implemented.")


class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
    """
    AllGather Based Token dispatcher.
    """

    def __init__(
        self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
    ) -> None:
        """
        Initialize the zero token dropping router.
        """
        super().__init__(config=config)
        self.num_local_experts = num_local_experts
        assert self.num_local_experts > 0, "Expected at least one expert"
        self.local_expert_indices = local_expert_indices
        assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
        self.router_topk = config.moe_router_topk
        self.add_bias = config.add_bias_linear

        # self.local_probs: probs of global token assignment to local experts.
        self.local_probs = None

        # self.indices: The indices of `local_indices` (which holds the un-sorted expert indices
        # of tokens that the local experts can process) that give its sorted order along dim 0.
        self.indices = None

        # self.global_local_map: 2D tensor. A mask of mapping between global and local tokens
        # where each element is True if it's between the local_expert_indices. Only useful when
        # cross-device token permutation is enabled and **AllGather** is performed.
        self.global_local_map = None

    def token_permutation(
        self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind: torch.Tensor
    ):
        """Dispatch tokens to local experts. It's composed of two stages:
        (1) Permute the tokens across the expert parallel devices.
        After this stage, each device receives all of the tokens assigned to its local set
        of experts in its local HBM.
        (2) Permute the tokens locally so that they are grouped by their expert assignment.
        After stage (1), the tokens are grouped by which device they came from. We re-order
        them locally for subsequent efficient computation.

        Args:
            hidden_states: input tokens of shape [SeqLen/TP, MBS, HiddenSize]
            max_prob: probs of local token assignment to global experts.
            max_ind: token assignment to local experts.

        Returns:
            permuted_local_hidden_states: Permutation of tokens to local experts group.
            tokens_per_expert: the number of tokens each local expert has to process.
        """
        self.hidden_shape = hidden_states.shape
        # [S/TP, B, H] -> [S*B/TP, H]
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

        # Permute the tokens across the expert parallel devices.
        if (self.config.tensor_model_parallel_size > 1) or (
            self.config.expert_model_parallel_size > 1
        ):
            with torch.no_grad():
                global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
                    max_ind
                )
                # Create a mask of mapping between global and local tokens where each
                # element is True if it's between the local_expert_indices
                global_local_mask = (global_indices >= self.local_expert_indices[0]) & (
                    global_indices <= self.local_expert_indices[-1]
                )
                local_indices = global_indices.masked_select(global_local_mask)

            if self.router_topk > 1:  # k > 1
                global_probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(max_prob)
                self.local_probs = global_probs.masked_select(global_local_mask)
            else:
                self.local_probs = max_prob

            # [S*B/TP, H] -> [S*B, H]
            global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
                hidden_states, use_global_buffer=True
            )
            # Reshape global_local_mask to be compatible with Tensor.gather
            global_local_map = global_local_mask.nonzero()[:, 0]
            self.global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
            local_hidden_states = moe_gather.apply(global_hidden_states, self.global_local_map)
        else:
            if self.router_topk > 1:
                global_local_mask = torch.ones_like(max_ind).bool()
                local_indices = max_ind.masked_select(global_local_mask)
                self.local_probs = max_prob.masked_select(global_local_mask)
                global_local_map = global_local_mask.nonzero()[:, 0]
                self.global_local_map = global_local_map.view(-1, 1).expand(
                    -1, hidden_states.shape[-1]
                )
                local_hidden_states = torch.gather(hidden_states, 0, self.global_local_map)
            else:
                local_indices = max_ind
                self.local_probs = max_prob
                local_hidden_states = hidden_states
                self.global_local_map = None

        with torch.no_grad():
            # The indices of local_indices that give its sorted order along dim 0.
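            # Sorting the expert ids yields the gather order that groups tokens
            # expert-by-expert, while histc over the same ids counts how many rows
            # each local expert receives.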
            self.indices = torch.argsort(local_indices, dim=0)
            tokens_per_expert = torch.histc(
                local_indices,
                bins=self.num_local_experts,
                min=self.local_expert_indices[0],
                max=self.local_expert_indices[-1],
            )
            tokens_per_expert = tokens_per_expert.cpu().to(torch.long)

        # Stage2: permute the tokens locally so that they are grouped by their expert assignment
        # Reshape indices to be compatible with Tensor.gather
        self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
        if self.num_local_experts > 1:
            permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
        else:
            permuted_local_hidden_states = local_hidden_states
        return (
            permuted_local_hidden_states,
            tokens_per_expert,
        )

    def token_unpermutation(
        self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
    ):
        """
        Reverse process of `dispatch()` which permutes the output of local experts locally
        and across the expert parallel ranks into the original order to produce the final output.

        Args:
            hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize],
            output of local experts.
            bias (optional): The bias tensor.

        Returns:
            output_total: un-permuted updated hidden states output from all local experts
            with shape of [SeqLen/TP, MBS, HiddenSize]
        """
        # Stage1: unpermute the tokens and bias locally respectively.
        scores = self.local_probs.to(dtype=hidden_states.dtype)
        if self.num_local_experts > 1:
            assert self.indices.shape == hidden_states.shape
            unpermuted_local_hidden = moe_scatter.apply(hidden_states, self.indices)
        else:
            unpermuted_local_hidden = hidden_states

        # Scale the expert output prior to reduction and subsequent to local unpermutation
        # if k > 1.
        if self.router_topk > 1:
            unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)

        unpermuted_local_bias = None
        if self.add_bias:
            assert bias is not None
            unpermuted_local_bias = torch.zeros_like(hidden_states)
            assert self.indices.shape == bias.shape
            unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias)
            if self.router_topk > 1:
                unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)

        output_total = unpermuted_local_hidden
        output_bias_total = unpermuted_local_bias

        # Unpermute the tokens across expert parallel devices.
        if (self.config.tensor_model_parallel_size > 1) or (
            self.config.expert_model_parallel_size > 1
        ):
            assert (
                self.global_local_map is not None
            ), "global_local_map is necessary for `AllGather`."
            ep_group_size = parallel_state.get_tensor_and_expert_parallel_world_size()
            # hidden_shape: [SeqLen/TP, MBS, HiddenSize]; global_num_tokens = SeqLen/TP*MBS*(TP*EP)
            global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
            global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
            assert self.global_local_map.shape == unpermuted_local_hidden.shape
            unpermuted_global_hidden = moe_scatter.apply(
                unpermuted_local_hidden, self.global_local_map, global_hidden_shape
            )
            output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
                unpermuted_global_hidden
            )
            if self.add_bias:
                # Unpermute the bias across expert parallel devices.
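                # The bias mirrors the hidden-state path: scatter the local rows into a
                # buffer of global shape, then reduce-scatter back to the sequence split.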
                unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
                unpermuted_global_bias = unpermuted_global_bias.scatter_add(
                    0, self.global_local_map, unpermuted_local_bias
                )
                output_bias_total = (
                    tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
                        unpermuted_global_bias
                    )
                )
                # Bias is duplicated across tensor parallel ranks; the reduce-scatter
                # sums it, so divide by the tensor parallel world size.
                output_bias_total = (
                    output_bias_total / parallel_state.get_tensor_model_parallel_world_size()
                )
        else:
            if self.router_topk > 1:
                global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
                global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
                unpermuted_global_hidden = torch.zeros(
                    global_hidden_shape,
                    dtype=hidden_states.dtype,
                    device=torch.cuda.current_device(),
                )
                output_total = unpermuted_global_hidden.scatter_add(
                    0, self.global_local_map, unpermuted_local_hidden
                )
                if self.add_bias:
                    unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
                    output_bias_total = unpermuted_global_bias.scatter_add(
                        0, self.global_local_map, unpermuted_local_bias
                    )

        if self.router_topk == 1:
            output_total = output_total * scores
        output_total = output_total.view(self.hidden_shape)
        if self.add_bias:
            assert output_bias_total is not None
            if self.router_topk == 1:
                output_bias_total = output_bias_total * scores
            output_bias_total = output_bias_total.view(self.hidden_shape)
        else:
            output_bias_total = None

        return output_total, output_bias_total


class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
    """
    AlltoAll Based Token dispatcher.
    """

    def __init__(
        self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
    ) -> None:
        """
        Initialize the AlltoAll token dispatcher.

        Args:
            num_local_experts (int): Number of local experts on the current device.
            local_expert_indices (List[int]): Indices of local experts on the current device.
            config (TransformerConfig): Configuration for the transformer model.
        """
        super().__init__(config=config)
        self.hidden_shape = None
        self.num_input_tokens = None
        self.num_local_experts = num_local_experts
        self.num_experts = config.num_moe_experts
        assert self.num_local_experts > 0, "Expected at least one expert"
        self.local_expert_indices = local_expert_indices
        assert (
            len(self.local_expert_indices) == self.num_local_experts
        ), "Invalid local expert indices"
        self.router_topk = config.moe_router_topk
        self.add_bias = config.add_bias_linear
        self.ep_size = config.expert_model_parallel_size
        self.probs = None
        self.input_splits = None
        self.output_splits = None
        self.num_global_tokens_per_local_expert = None

        # Token drop and padding.
        # We need to keep track of the token num if we drop tokens without padding them.
        self.num_out_tokens = None
        # Drop and pad the input to capacity.
        self.drop_and_pad = self.config.moe_pad_expert_input_to_capacity
        if self.drop_and_pad:
            assert self.config.moe_expert_capacity_factor is not None
        self.capacity = None

    def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Preprocess token indices for AlltoAll communication and token permutation. This method
        computes the number of tokens assigned to each expert based on the input indices.
        It also initializes the necessary data structures for AlltoAll communication, such as
        input and output splits, and the mapping between global tokens and local experts.

        Args:
            indices (torch.Tensor): Tensor of indices mapping tokens to experts.

        Returns:
            torch.Tensor: Tensor containing the number of tokens assigned to each local expert.
        """
        num_local_tokens_per_expert = torch.histc(
            indices, bins=self.num_experts, min=0, max=self.num_experts
        )
        # num_local_tokens_per_expert: [num_experts]
        ep_size = self.config.expert_model_parallel_size
        if self.drop_and_pad:
            # probs: [num_experts, capacity]
            self.capacity = self.probs.size(1)
            num_tokens_per_local_expert = torch.full(
                (self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long
            )
            return num_tokens_per_local_expert
        elif self.config.moe_expert_capacity_factor is not None:
            self.num_out_tokens = num_local_tokens_per_expert.sum().cpu()

        if ep_size > 1:
            # ===================================================
            # Calculate input_splits, output_splits for alltoall-v.
            # ===================================================
            self.input_splits = (
                num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
                .sum(axis=1)
                .to(torch.device("cpu"))
                .numpy()
            )
            num_global_tokens_per_expert = _gather_along_first_dim_expert_parallel(
                num_local_tokens_per_expert
            ).reshape(ep_size, self.num_experts)
            self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
                :, self.local_expert_indices
            ]
            self.output_splits = (
                self.num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu")).numpy()
            )
            num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0).to(
                torch.device("cpu"), non_blocking=True
            )
            # ===================================================
            # num_global_tokens_per_expert: [ep_size, num_experts]
            # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
            # num_tokens_per_local_expert: [num_local_experts]
            # ===================================================
        else:
            self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
                -1, self.num_experts
            )
            num_tokens_per_local_expert = num_local_tokens_per_expert.to(
                torch.device("cpu"), non_blocking=True
            )

        if self.num_local_experts > 1:
            expert_ids_per_ep_rank = torch.tensor(
                [i % self.num_local_experts for i in range(self.config.num_moe_experts)],
                dtype=torch.int32,
                device=torch.cuda.current_device(),
            )
            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
                expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
            )

        return num_tokens_per_local_expert

    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Dispatch tokens to local experts using AlltoAll communication.

        Args:
            hidden_states (torch.Tensor): Input token embeddings.
            probs (torch.Tensor): Probs of tokens assigned to experts.
            indices (torch.Tensor): Indices of tokens assigned to experts.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Permuted token embeddings for local experts.
                - Number of tokens per expert.
        """
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        self.hidden_shape = hidden_states.shape
        self.probs = probs
        assert probs.dim() == 2, "Expected 2D tensor for probs"
        assert indices.dim() == 2, "Expected 2D tensor for indices"
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        tokens_per_expert = self.preprocess(indices)

        # Perform tensor parallel AlltoAll communication
        # hidden_states: [S*B/TP, H] -> [S*B, H/TP]
        if parallel_state.get_tensor_model_parallel_world_size() > 1:
            hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)

        # Permutation 1: input to AlltoAll input
        self.hiddden_shape_before_permute = hidden_states.shape
        permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
            hidden_states,
            indices,
            num_out_tokens=self.num_out_tokens,
            padded_mode=self.drop_and_pad,
        )

        # Perform expert parallel AlltoAll communication
        global_input_tokens = tensor_parallel.all_to_all(
            parallel_state.get_expert_model_parallel_group(),
            permutated_local_input_tokens,
            self.output_splits,
            self.input_splits,
        )

        # Permutation 2: Sort alltoall output by local experts when num_local_experts > 1.
        if self.num_local_experts > 1:
            if not self.drop_and_pad:
                global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
                    global_input_tokens, self.global_input_tokens_local_experts_indices
                )
            else:
                global_input_tokens = global_input_tokens.reshape(
                    self.ep_size, self.num_local_experts, self.capacity, -1
                )
                global_input_tokens = (
                    global_input_tokens.transpose(0, 1)
                    .reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
                    .contiguous()
                )

        # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
        # global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
        if parallel_state.get_tensor_model_parallel_world_size() > 1:
            global_input_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
                global_input_tokens
            )

        return global_input_tokens, tokens_per_expert

    def token_unpermutation(
        self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Reverse the token permutation to restore the original order.

        Args:
            hidden_states (torch.Tensor): Output from local experts.
            bias (torch.Tensor, optional): Bias tensor (not supported).

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - Unpermuted token embeddings in the original order.
                - None (bias is not supported).
        """
        assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"

        # Perform tensor parallel Reduce-Scatter
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        if parallel_state.get_tensor_model_parallel_world_size() > 1:
            hidden_states = tensor_parallel.reduce_scatter_last_dim_to_tensor_parallel_region(
                hidden_states
            )

        # Unpermutation 2: expert output to AlltoAll input
        if self.num_local_experts > 1:
            if not self.drop_and_pad:
                hidden_states = unpermute(
                    hidden_states, self.reversed_global_input_permutation_mapping,
                )
            else:
                hidden_states = hidden_states.reshape(
                    self.num_local_experts, self.ep_size, self.capacity, -1
                )
                hidden_states = (
                    hidden_states.transpose(0, 1)
                    .reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
                    .contiguous()
                )

        # Perform expert parallel AlltoAll communication
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        permutated_local_input_tokens = tensor_parallel.all_to_all(
            parallel_state.get_expert_model_parallel_group(),
            hidden_states,
            self.input_splits,
            self.output_splits,
        )

        # Unpermutation 1: AlltoAll output to output
        output = unpermute(
            permutated_local_input_tokens,
            self.reversed_local_input_permutation_mapping,
            probs=self.probs,
            padded_mode=self.drop_and_pad,
            restore_shape=self.hiddden_shape_before_permute,
        )

        # Perform tensor parallel AlltoAll communication
        # output: [S*B, H/TP] -> [S*B/TP, H]
        if parallel_state.get_tensor_model_parallel_world_size() > 1:
            output = tensor_parallel.all_to_all_hp2sp(output)

        # Reshape the output tensor
        output = output.view(self.hidden_shape)
        return output, None