# megatron_patch/model/mixtral_bak/moe/token_dispatcher.py
# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.parallel_state import get_tensor_and_expert_parallel_group

from ..transformer_config import TransformerConfig


class MoETokenDispatcher:
"""
MoE Token Dispatcher
"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.config = config

    @abstractmethod
def token_permutation(
self, tokens: torch.Tensor, indices: torch.Tensor,
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
indices (torch.Tensor): indices tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise NotImplementedError("Dispatch function not implemented.")

    @abstractmethod
def token_unpermutation(
self, expert_output: torch.Tensor, scores: torch.Tensor, indices: torch.Tensor,
):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
scores (torch.Tensor): Each token's score with each expert.
indices (torch.Tensor): The indices used to reorder the expert output.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise NotImplementedError("Restore function not implemented.")


class MoEDroplessTokenDispatcher(MoETokenDispatcher):
"""
Token dispatcher without token dropping.
"""

    def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the zero token dropping router.
Args:
num_local_experts (int): The number of experts in the local process/group.
local_expert_indices (List[int]): The indices of the experts that are local
to the current process. These indices identify the experts within the
larger, global set of experts in a distributed setup.
config (TransformerConfig): An instance of TransformerConfig that contains
various configuration settings for the model such as the number of
experts, model parallelism settings, and other relevant parameters.
Returns:
None
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
self.local_expert_indices = local_expert_indices
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear

    def gather_indices(self, local_indices: torch.Tensor):
"""
Gather tensors and concatenate along the first dimension.
Args:
local_indices (torch.Tensor): Tensor of indices on the local device.
Returns:
torch.Tensor: Tensor containing the concatenated indices from all devices.
"""
group = get_tensor_and_expert_parallel_group()
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return local_indices
dim_size = list(local_indices.size())
dim_size[0] = dim_size[0] * world_size
# TODO pre allocate memory
output = torch.empty(
dim_size, dtype=local_indices.dtype, device=torch.cuda.current_device()
)
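        # _all_gather_base writes every rank's `local_indices` into a single flat buffer,
        # ordered by rank within the tensor-and-expert-parallel group, so `output` holds
        # the routing indices of all tokens in that group along dim 0.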
torch.distributed._all_gather_base(output, local_indices.contiguous(), group=group)
return output

    def token_permutation(
self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind: torch.Tensor
):
"""Dispatch tokens to local experts. It's composed of two stages:
(1) Permute the tokens across the expert parallel devices. After this stage,
each device receives all of the tokens assigned to its local set of experts
in its local HBM.
(2) Permute the tokens locally so that they are grouped by their expert
assignment. After the stage (1), the tokens are grouped by which device
they came from. We re-order them locally for subsequent efficient computation.
Args:
hidden_states: input tokens of shape [SeqLen/TP, MBS, HiddenSize]
max_prob: probs of token assignment to local experts.
max_ind: token assignment to local experts.
Returns:
permuted_local_hidden_states: Permutation of tokens to local experts group.
tokens_per_expert: the number of tokens each local expert to process.
indices: The indices of `local_indices` (which holds the un-sorted expert
indices of tokens that local expert can process) that give its sorted order along dim 0.
global_local_map (optional): 2D tensor. A mask of mapping between global and local tokens where each
element is True if it's between the local_expert_indices. Only useful
when cross device token permutation is enabled and **AllGahter** is performed.
"""
self.hidden_shape = hidden_states.shape
# [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Permute the tokens across the expert parallel devices.
if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
# [S*B/TP, H] -> [S*B, H]
global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
hidden_states
)
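            # The gather above collects the [S*B/TP, H] shards from every rank in the
            # tensor-and-expert-parallel group, so each rank now holds the full [S*B, H]
            # token set before selecting the rows for its local experts.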
with torch.no_grad():
global_indices = self.gather_indices(max_ind)
# Create a mask of mapping between global and local tokens where each
# element is True if it's between the local_expert_indices
global_local_map = (global_indices >= self.local_expert_indices[0]) & (
global_indices <= self.local_expert_indices[-1]
)
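                # Illustrative example (assumed toy sizes, not from the original code):
                # with 4 global experts, EP=2, and local_expert_indices=[0, 1] on this rank,
                # gathered indices [2, 0, 1, 3] give a mask [False, True, True, False],
                # so only the tokens routed to experts 0 and 1 are kept locally below.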
local_indices = global_indices.masked_select(global_local_map)
if self.router_topk > 1: # k > 1
global_probs = self.gather_indices(max_prob)
local_probs = global_probs.masked_select(global_local_map)
else:
local_probs = max_prob
# Reshape global_local_map to be compatible with Tensor.gather
global_local_map = global_local_map.nonzero()[:, 0]
global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
local_hidden_states = torch.gather(global_hidden_states, 0, global_local_map)
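            # The gather above pulls, from the all-gathered activations, only the rows whose
            # routed expert lives on this rank; a token routed to k local experts appears k times.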
else:
if self.router_topk > 1:
global_local_map = torch.ones_like(max_ind).bool()
local_indices = max_ind.masked_select(global_local_map)
local_probs = max_prob.masked_select(global_local_map)
global_local_map = global_local_map.nonzero()[:, 0]
global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
local_hidden_states = torch.gather(hidden_states, 0, global_local_map)
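                # With no cross-device permutation the mask is all-True, so this gather simply
                # replicates each token row once per routed expert (top-k copies) before sorting.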
else:
local_indices = max_ind
local_probs = max_prob
local_hidden_states = hidden_states
global_local_map = None
with torch.no_grad():
# The indices of local_indices that give its sorted order along dim 0.
indices = torch.argsort(local_indices, dim=0)
tokens_per_expert = torch.histc(
local_indices,
bins=self.num_local_experts,
min=self.local_expert_indices[0],
max=self.local_expert_indices[-1],
)
tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
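            # Illustrative example (assumed values, not from the original code): with 2 local
            # experts (indices [4, 5]) and local_indices = [5, 4, 5, 4], argsort gives a
            # permutation such as [1, 3, 0, 2] and histc counts tokens_per_expert = [2, 2].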
# Stage2: permute the tokens locally so that they are grouped by their expert assignment
# Reshape indices to be compatible with Tensor.gather
indices = indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
permuted_local_hidden_states = torch.gather(local_hidden_states, 0, indices)
return (
permuted_local_hidden_states,
tokens_per_expert,
local_probs,
indices,
global_local_map,
)

    def token_unpermutation(
self,
hidden_states: torch.Tensor,
scores: torch.Tensor,
indices: torch.Tensor,
global_local_map: torch.Tensor = None,
bias: torch.Tensor = None,
):
"""
Reverse process of `dispatch()` which permutes the ouput of local
experts locallay and across expert parallel rank into the original order to
produce the final output.
Args:
hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize],
ouput of local experts.
scores: 2D tensor of the probs of token assignment to local experts.
indices: 2D tensor of the indices of `local_indices` (which holds the un-sorted expert
indices of tokens that local expert can process) that give its sorted order along dim 0.
global_local_map (optional): 2D tensor, a mask of mapping between global and local tokens where each
element is True if it's between the local_expert_indices. Only useful
when cross device token permutation is enabled and **AllGather** is performed.
bias (optional): The bias tensor.
Returns:
output_total: un-permuted updated hidden states output from all local experts
with shape of [SeqLen/TP, MBS, HiddenSize]
"""
# Stage1: unpermute the tokens and bias locally respectively.
scores = scores.to(dtype=hidden_states.dtype)
unpermuted_local_hidden = torch.zeros_like(hidden_states)
assert indices.shape == hidden_states.shape
unpermuted_local_hidden = unpermuted_local_hidden.scatter(0, indices, hidden_states)
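        # The scatter above uses the row-expanded permutation saved by token_permutation, so it
        # exactly inverts the stage-2 gather: each expert output row returns to its pre-sort
        # position among the local tokens.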
# Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1.
if self.router_topk > 1:
unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
unpermuted_local_bias = None
if self.add_bias:
assert bias is not None
unpermuted_local_bias = torch.zeros_like(hidden_states)
assert indices.shape == bias.shape
unpermuted_local_bias = unpermuted_local_bias.scatter(0, indices, bias)
if self.router_topk > 1:
unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
output_total = unpermuted_local_hidden
output_bias_total = unpermuted_local_bias
# Unpermute the tokens across expert parallel devices.
if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
assert global_local_map is not None, "global_local_map is necessary for `AllGather`."
ep_group_size = parallel_state.get_tensor_and_expert_parallel_world_size()
            # hidden_shape: [SeqLen/TP, MBS, HiddenSize]; global_num_tokens = SeqLen/TP * MBS * (TP*EP)
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
unpermuted_global_hidden = torch.zeros(
global_hidden_shape, dtype=hidden_states.dtype, device=torch.cuda.current_device()
)
# Reshape global_local_map to be compatible with Tensor.scatter
assert global_local_map.shape == unpermuted_local_hidden.shape
unpermuted_global_hidden = unpermuted_global_hidden.scatter_add(
0, global_local_map, unpermuted_local_hidden
)
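            # scatter_add places every locally processed row back into its global token slot;
            # with top-k routing, contributions from different local experts for the same
            # token land in the same slot and are summed.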
output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_hidden
)
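            # The reduce-scatter sums the per-rank partial results across the TP*EP group and
            # hands each rank back its own [SeqLen/TP * MBS] shard, undoing the earlier
            # gather_from_sequence_parallel_region_to_moe.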
if self.add_bias:
# Unpermute the bias across expert parallel devices.
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
unpermuted_global_bias = unpermuted_global_bias.scatter_add(
0, global_local_map, unpermuted_local_bias
)
output_bias_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_bias
)
                # The bias is duplicated across tensor parallel ranks; the reduce-scatter
                # sums it, so divide by the TP world size to undo the duplication.
output_bias_total = (
output_bias_total / parallel_state.get_tensor_model_parallel_world_size()
)
else:
if self.router_topk > 1:
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
unpermuted_global_hidden = torch.zeros(
global_hidden_shape,
dtype=hidden_states.dtype,
device=torch.cuda.current_device(),
)
output_total = unpermuted_global_hidden.scatter_add(
0, global_local_map, unpermuted_local_hidden
)
if self.add_bias:
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
output_bias_total = unpermuted_global_bias.scatter_add(
0, global_local_map, unpermuted_local_bias
)
if self.router_topk == 1:
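            # For k = 1 each token has a single gate score, so the scaling that the k > 1 path
            # applied before the reduction is instead applied once here on the restored output.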
output_total = output_total * scores
output_total = output_total.view(self.hidden_shape)
if self.add_bias:
assert output_bias_total is not None
if self.router_topk == 1:
output_bias_total = output_bias_total * scores
output_bias_total = output_bias_total.view(self.hidden_shape)
else:
output_bias_total = None
return output_total, output_bias_total