megatron_patch/model/deepseek_v2/moe/shared_experts.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
from copy import deepcopy
from typing import Optional
import torch
import torch.nn.functional as F
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_sharded_tensor_for_checkpoint
from ..mlp import MLP
class SharedExpertMLP(MLP):
"""
MLP layer for Shared Experts.
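
    When '--moe-shared-expert-overlap' is set, the forward pass is additionally
    exposed as the staged functions pre_forward_comm / linear_fc1_forward_and_act /
    linear_fc2_forward / post_forward_comm / get_output, so the token dispatcher
    can overlap them with its communication (see __init__ for the required call order).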
"""
# This stream is used when '--moe-shared-expert-overlap' is set.
# The shared experts are scheduled into this stream to be overlapped with the dispatcher.
stream = None
def __init__(self, config: TransformerConfig, spec: ModuleSpec):
config = deepcopy(config)
        assert config.add_bias_linear == False, (
            "bias is not supported in the shared experts, "
            "please set '--disable-bias-linear' instead."
        )
config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
super().__init__(config=config, submodules=spec.submodules)
self.use_shared_expert_gate = spec.params.get("gate", False)
if self.use_shared_expert_gate:
self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
if config.perform_initialization:
if get_cuda_rng_tracker().is_initialized():
with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
config.init_method(self.gate_weight)
else:
config.init_method(self.gate_weight)
self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
else:
self.gate_weight = None
if self.config.moe_shared_expert_overlap:
# disable TP related AG/RS communications in the linear module
for linear in [self.linear_fc1, self.linear_fc2]:
if hasattr(linear, 'parallel_mode'):
# TELinear
linear.parallel_mode = None
else:
# MCore legacy Linear
linear.explicit_expert_comm = True
            # The overlapped version is split into several separate functions that are
            # called from inside the token dispatcher. These functions must be called
            # in this order and none of them can be skipped:
# pre_forward_comm(input)
# linear_fc1_forward_and_act()
# linear_fc2_forward()
# post_forward_comm()
# output = get_output()
#
# We use cached intermediate results to avoid messy arg passing in the dispatcher.
self.cached_fc1_input = None
self.cached_fc2_input = None
self.cached_fc2_output = None
self.cached_output = None
self.gate_score = None
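            # Lazily create the class-level side stream on first use; it is shared by
            # all SharedExpertMLP instances so their compute can overlap with the
            # dispatcher's communication on the default stream.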
if self.stream is None:
self.stream = torch.cuda.Stream()
def forward(self, hidden_states):
"""Forward function"""
output, _ = super().forward(hidden_states)
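        # Optionally scale the shared-expert output by a per-token sigmoid gate
        # computed from the input hidden states.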
if self.use_shared_expert_gate:
logits = torch.nn.functional.linear(hidden_states, self.gate_weight)
gate_score = torch.nn.functional.sigmoid(logits)
output = output * gate_score
return output
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""Gets sharded state dict."""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
if self.use_shared_expert_gate:
name = 'gate_weight'
state_dict = self.state_dict(prefix='', keep_vars=True)
sub_sd = {
f'{prefix}{name}': make_sharded_tensor_for_checkpoint(
state_dict[name], f'{prefix}{name}', prepend_offsets=sharded_offsets
)
}
sharded_state_dict.update(sub_sd)
return sharded_state_dict
def pre_forward_comm(self, input):
"""
All Gather for SP before forward.
This function is used to overlap shared experts with the dispatcher.
        It is only used when --moe-shared-expert-overlap is set and may change in the future.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_output is None
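        # Make the side stream wait for work already queued on the current (default)
        # stream so it reads an up-to-date input tensor.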
self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
if self.use_shared_expert_gate:
logits = torch.nn.functional.linear(input, self.gate_weight)
self.gate_score = torch.nn.functional.sigmoid(logits)
if self.config.sequence_parallel:
self.cached_fc1_input = gather_from_sequence_parallel_region(
input, tensor_parallel_output_grad=True
)
else:
self.cached_fc1_input = copy_to_tensor_model_parallel_region(input)
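            # Give this tensor's grad_fn the largest possible sequence number so
            # autograd runs its backward (the matching SP communication) as early
            # as possible.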
set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max)
def linear_fc1_forward_and_act(self, overlapped_comm_output=None):
"""
Do Linear FC1 and activation function forward.
This function is used to overlap shared experts with the dispatcher.
        It is only used when --moe-shared-expert-overlap is set and may change in the future.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_fc1_input is not None
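        # If the dispatcher hands over the output tensor of its overlapped
        # communication, pin its backward to run early as well.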
if overlapped_comm_output is not None:
set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max)
with torch.cuda.stream(self.stream):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input)
self.cached_fc1_input = None
if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
if self.config.gated_linear_unit:
intermediate_parallel = bias_geglu_impl(
intermediate_parallel, bias_parallel
)
else:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(
intermediate_parallel,
bias_parallel,
self.config.activation_func_fp8_input_store,
)
else:
                raise ValueError("Only gelu and swiglu fusion are supported")
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
intermediate_parallel = glu(intermediate_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel)
self.cached_fc2_input = intermediate_parallel
def linear_fc2_forward(self, overlapped_comm_output=None):
"""
Do Linear FC2 forward.
This function is used to overlap shared experts with the dispatcher.
        It is only used when --moe-shared-expert-overlap is set and may change in the future.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_fc2_input is not None
if overlapped_comm_output is not None:
set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max)
with torch.cuda.stream(self.stream):
# [s, b, h]
self.cached_fc2_output, _ = self.linear_fc2(self.cached_fc2_input)
self.cached_fc2_input = None
def post_forward_comm(self):
"""
Reduce scatter for SP after forward.
This function is used to overlap shared experts with the dispatcher.
        It is only used when --moe-shared-expert-overlap is set and may change in the future.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_fc2_output is not None
with torch.cuda.stream(self.stream):
if self.config.sequence_parallel:
self.cached_output = reduce_scatter_to_sequence_parallel_region(
self.cached_fc2_output
)
else:
self.cached_output = reduce_from_tensor_model_parallel_region(
self.cached_fc2_output
)
self.cached_fc2_output = None
set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max)
def get_output(self):
"""
Gets the module forward output.
This function is used to overlap shared experts with the dispatcher.
        It is only used when --moe-shared-expert-overlap is set and may change in the future.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_output is not None
with torch.cuda.stream(self.stream):
if self.use_shared_expert_gate:
assert self.gate_score is not None
output = self.cached_output * self.gate_score
self.gate_score = None
else:
output = self.cached_output
self.cached_output = None
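        # Let the default stream wait for the side stream before the caller
        # consumes the shared-expert output.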
torch.cuda.current_stream().wait_stream(self.stream)
return output
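
# Parse the installed PyTorch version; set_tensor_grad_fn_sequence_sr() below only
# touches grad_fn._set_sequence_nr() when the running PyTorch supports it.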
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
TORCH_LAST = torch.__version__.split(".")[2]
def set_tensor_grad_fn_sequence_sr(tensor, value):
"""
    Set the sequence number (sequence_nr) of a tensor's grad_fn to control the backward order.
    For older PyTorch versions, do nothing (the backward order is not changed).
    The bigger the value is, the earlier the grad_fn is scheduled in the backward pass.
"""
if (
(TORCH_MAJOR > 2)
or (TORCH_MAJOR == 2 and TORCH_MINOR > 2)
or (TORCH_MAJOR == 2 and TORCH_MINOR == 2 and '+' not in TORCH_LAST)
):
# In NVIDIA PyTorch container 24.01, the PyTorch version is 2.2.0a0+81ea7a4,
        # which does not contain the set_sequence_nr commit.
if tensor is not None and tensor.grad_fn is not None:
tensor.grad_fn._set_sequence_nr(value)
else:
warnings.warn(
"WARNING : PyTorch is too old to set sequence_sr and the performance may not "
"optimal. Please use PyTorch >= 2.2.0 for better performance."
)
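
# Minimal usage sketch of the overlapped path, assuming a hypothetical
# `shared_experts = SharedExpertMLP(config, spec)` instance driven by a token
# dispatcher (the dispatcher's own communication runs on the default stream in
# between these calls, and its outputs may be passed in as `overlapped_comm_output`
# to pin their backward order):
#
#   shared_experts.pre_forward_comm(hidden_states)
#   shared_experts.linear_fc1_forward_and_act()
#   shared_experts.linear_fc2_forward()
#   shared_experts.post_forward_comm()
#   shared_expert_output = shared_experts.get_output()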