in mesh_tensorflow/transformer/heterogeneous_moe.py [0:0]
def transformer_moe_layer_v1(
inputs, output_dim, hparams, train, variable_dtype,
layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
num_microbatches=None, token_embeddings=None, context=None):
"""Local heterogenous mixture of experts.
See transformer_moe_layer_v1 in moe.py for a more detailed explanation for
a generic moe layer.
The heterogeneous mask outputted by generate_heterogeneous_expert_masks has
dimension [maximum hidden size, maximum # layers, # experts] and its shape
will overwrite the parameters moe_num_layers and moe_hidden_size in hparams.
The layer-specific mask slice is applied at each expert layer to the
activation which is [expert width, # experts]. If the heterogeneous_mask_info
is None, there is no mask applied and the code is equivalent to the
homogeneous case.
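For example (purely illustrative sizes): with two experts whose hidden sizes
are 2048 and 4096 and whose depths are 2 and 3 layers, the mask has shape
[4096, 3, 2], and each layer's slice zeroes out the hidden units (and, for
the shallower expert, the extra layer) that a given expert does not use.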
The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
of the representations of all positions in a batch of sequences.
Each position of each sequence is sent to 0-2 experts. The expert
choices and the combination weights are determined by a learned gating
function.
This function returns a small auxiliary loss that should be added to the
training loss of the model. This loss helps to balance expert usage.
Without the loss, it is very likely that a few experts will be trained and
the rest will starve.
Dimensions cheat sheet:
B: batch dim(s)
L: original sequence length
M: input depth
N: output depth
G: number of groups
S: group size
E: number of experts
C: expert capacity
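As a purely illustrative instantiation of these dimensions: a batch of B=8
sequences of length L=512 with input depth M=1024 gives 8 * 512 = 4096
positions per replica, which might be split into G=4 groups of S=1024
positions each and routed among E=32 experts, with capacity
C ~ S * capacity_factor / E.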
Args:
inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
output_dim: a mtf.Dimension (for Transformer, this is input_dim)
hparams: model hyperparameters
train: a boolean
variable_dtype: a mtf.VariableDType
layout: optional - an input to mtf.convert_to_layout_rules
mesh_shape: optional - an input to mtf.convert_to_shape
nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
and the same dtype as inputs, consisting of ones(nonpadding)
and zeros(padding).
activation: a function.
num_microbatches: number of microbatches.
token_embeddings: a mtf.Tensor with shape
[batch_dim(s), length_dim, input_dim]. These are the word embeddings that
correspond to the inputs. They can optionally be used to make routing
decisions.
context: a Context.
Returns:
outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
loss: a mtf scalar
Raises:
ValueError: on unrecognized hparams.moe_gating
"""
orig_inputs = inputs
experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)
if hparams.moe_heterogeneous_mask_info is not None:
tf.logging.info("moe_heterogeneous_mask_info: {}".format(
hparams.moe_heterogeneous_mask_info))
heterogeneous_mask = generate_heterogeneous_expert_masks(
hparams.moe_heterogeneous_mask_info,
hparams.moe_num_experts,
experts_dim,
mesh=inputs.mesh,
expert_width=hparams.moe_hidden_size)
# Overwrite the depth (number of layers) and width (hidden size) with the
# mask's maximum dimensions.
hparams.moe_num_layers = heterogeneous_mask.shape[1].size
hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
# We "cheat" here and look at the mesh shape and layout. This is to ensure
# that the number of groups is a multiple of the mesh dimension
# over which those groups are split.
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
orig_inputs.shape.dims[-1])
# Hack: we assume that
# "outer_batch" == replication of experts
# mesh_dim_size can be derived from mesh_shape and orig_batch_dim
#
# We then require num_groups to be a multiple of mesh_dim_size.
if orig_inputs.shape.dims[0].name == "outer_batch":
outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
else:
outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
orig_inputs.shape.dims[0])
# Number of MoE inputs (total number of positions across
# batch_and_length_dims per replica).
n = 1
for d in batch_and_length_dims:
n *= d.size
n = n // outer_batch_dim.size
mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
orig_batch_dim)
num_groups, group_size = moe._split_into_groups( # pylint: disable=protected-access
n, hparams.moe_group_size, mesh_dim_size)
# TODO(barretzoph): implementation without pylint calls?
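# Illustrative example with assumed sizes (not taken from any config): with
# n = 4096 positions, hparams.moe_group_size = 1024 and mesh_dim_size = 4,
# _split_into_groups picks a group_size <= 1024 such that
# num_groups = n // group_size is a multiple of 4, e.g. (num_groups=4,
# group_size=1024).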
group_size_dim = mtf.Dimension("group", group_size)
num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
# OGSM Tensor
inputs = mtf.reshape(inputs, moe_input_dims)
# Token embeddings that can be optionally used in the router for determining
# where to send tokens.
if hparams.moe_word_embed_mode is not None:
token_embeddings = mtf.cast(
mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)
# Each group sends at most expert_capacity positions to each expert.
if train:
capacity_factor = hparams.moe_capacity_factor_train
else:
capacity_factor = hparams.moe_capacity_factor_eval
expert_capacity = min(
group_size_dim.size,
int((group_size_dim.size * capacity_factor) / experts_dim.size))
expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
tf.logging.info("expert_capacity: %d" % expert_capacity)
expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
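# Worked example with illustrative numbers: with group_size = 1024,
# capacity_factor = 1.25 and 32 experts, expert_capacity =
# min(1024, int(1024 * 1.25 / 32)) = 40, which is then clipped from below by
# hparams.moe_min_expert_capacity.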
experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
if nonpadding is not None:
nonpadding = mtf.zeros(
inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
if hparams.moe_gating == "top_2":
# combine_tensor and dispatch_tensor are OGSEC Tensors
# (G is generally split along a mesh dim)
dispatch_tensor, combine_tensor, loss = moe._top_2_gating( # pylint: disable=protected-access
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "top_n":
dispatch_tensor, combine_tensor, loss = moe._top_n_gating( # pylint: disable=protected-access
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "switch":
dispatch_tensor, combine_tensor, loss = moe._switch_gating( # pylint: disable=protected-access
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "ntlb":
dispatch_tensor, combine_tensor, loss = moe._ntlb_gating( # pylint: disable=protected-access
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "switch_max":
dispatch_tensor, combine_tensor, loss = moe._switch_max_gating( # pylint: disable=protected-access
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "expert_selection":
dispatch_tensor, combine_tensor, loss = moe._expert_selection_gating( # pylint: disable=protected-access
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
group_size_dim=group_size_dim,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
name="expert_selection_gating",
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
else:
raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
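# Both gating tensors have shape [outer_batch, num_groups, group_size,
# experts_unsplit, expert_capacity] (OGSEC). Informally: dispatch_tensor
# selects which expert/capacity slot (if any) each position's input is sent
# to, and combine_tensor holds the corresponding gate weights used to
# recombine the expert outputs further below.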
expert_inputs = mtf.einsum([inputs, dispatch_tensor],
mtf.Shape([
outer_batch_dim, experts_dim_unsplit,
num_groups_dim, expert_capacity_dim, input_dim
]))
# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
d_model_split_dim
]))
# Split over batch -> split over experts
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
input_dim
]))
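# A sketch of the intent of the two reshapes above (an informal reading of
# the comments, not verified against any particular layout): the first
# renames "expert_unsplit" -> "experts" and d_model -> "d_model_split" so
# the model dimension can carry the split during the transition; the second
# restores the d_model name, leaving the tensor split over experts rather
# than over the batch/group dimension.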
# Run the stack of expert layers. In the heterogeneous case, each layer uses
# its own slice of heterogeneous_mask (after reshaping, a per-layer mask of
# shape [expert_hidden, num_experts]).
for layer_idx in range(hparams.moe_num_layers):
with tf.variable_scope("expert_layer_{}".format(layer_idx)):
res_h = 0.0
if layer_idx > 0:
res_h = expert_inputs
expert_inputs = transformer.sublayer_rms_norm(
expert_inputs, None, context)
# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
reduced_dims=expert_inputs.shape.dims[-1:],
new_dims=[hidden_dim],
expert_dims=[experts_dim],
activation_functions=activation, use_bias=False,
variable_dtype=variable_dtype, name="wi")
# apply dropout
if hparams.moe_dropout_rate != 0.0:
h = mtf.dropout(h, is_training=train,
keep_prob=1.0 - hparams.moe_dropout_rate)
# only if heterogeneous
if hparams.moe_heterogeneous_mask_info is not None:
# Get mask for current layer by slicing heterogeneous mask
heterogeneous_mask_slice = mtf.slice(
heterogeneous_mask, layer_idx, 1, "num_expert_layers")
# Get rid of the expert layers dimension.
heterogeneous_mask_slice = mtf.reshape(
heterogeneous_mask_slice,
[heterogeneous_mask_slice.shape[0],
heterogeneous_mask_slice.shape[-1]])
h *= mtf.cast(heterogeneous_mask_slice, h.dtype)
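# Illustrative effect of the mask: if expert 0 is configured with hidden
# size 2048 while the maximum is 4096, its column of
# heterogeneous_mask_slice is zero for the hidden units it does not use, so
# those activations contribute nothing to the "wo" projection below.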
expert_output = mtf.layers.dense(
h, output_dim, expert_dims=[experts_dim], use_bias=False,
reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
name="wo")
if layer_idx < (hparams.moe_num_layers - 1):
expert_output = transformer.sublayer_dropout(
expert_output, None, context)
expert_output += res_h
expert_inputs = expert_output
# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, d_model_split_dim
]))
# Split over experts -> split over batch
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim,
experts_dim_unsplit,
num_groups_dim,
expert_capacity_dim,
output_dim,
]))
moe_output_dims = moe_input_dims[:-1] + [output_dim]
output = mtf.einsum([expert_output, combine_tensor],
mtf.Shape(moe_output_dims))
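# The combine einsum reverses the dispatch: each position sums the outputs
# of the expert/capacity slots it was routed to, weighted by the gate values
# in combine_tensor; positions that were dropped for exceeding expert
# capacity receive zeros here.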
output = mtf.reshape(output, batch_and_length_dims + [output_dim])
return output, loss * hparams.moe_loss_coef