in mesh_tensorflow/transformer/moe.py [0:0]
def transformer_moe_layer_v1(
inputs, output_dim, hparams, train, variable_dtype,
layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
num_microbatches=None, token_embeddings=None):
"""Local mixture of experts that works well on TPU.
Adapted from the paper https://arxiv.org/abs/1701.06538
Note: until the algorithm and interface solidify, we pass in a hyperparameters
dictionary in order not to complicate the interface in mtf_transformer.py.
Once this code moves out of "research", we should pass the hyperparameters
separately.
Hyperparameters used:
hparams.moe_num_experts: number of experts
hparams.moe_hidden_size: size of hidden layer in each expert
hparams.moe_group_size: size of each "group" for gating purposes
hparams.moe_capacity_factor_train: a float
hparams.moe_capacity_factor_eval: a float
hparams.moe_gating: a string
+ all hyperparameters used by _top_2_gating()
The number of parameters in the gating network is:
(input_dim.size * hparams.moe_num_experts)
The number of parameters in the experts themselves is:
(hparams.moe_num_experts
* (input_dim.size + output_dim.size)
* hparams.moe_hidden_size)
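For example, with the illustrative sizes used in the comments below
(input_dim = output_dim = 1024, moe_num_experts = 512,
moe_hidden_size = 8192), the gating network has 1024 * 512 = 524,288
parameters and the experts together have 512 * (1024 + 1024) * 8192
(about 8.6 billion) parameters.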
The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
of the representations of all positions in a batch of sequences.
Each position of each sequence is sent to 0-2 experts. The expert
choices and the combination weights are determined by a learned gating
function.
This function returns a small auxiliary loss that should be added to the
training loss of the model. This loss helps to balance expert usage.
Without the loss, it is very likely that a few experts will be trained and
the rest will starve.
Several hacks are necessary to get around current TPU limitations:
- To ensure static shapes, we enforce (by truncation/padding)
that each sequence sends the same number of elements to each expert.
It would make more sense to enforce this equality over the entire batch,
but due to our hacked-up gather-by-matmul implementation, we need to divide
the batch into "groups". For each group, the same number of elements
are sent to each expert.
TODO(noam): Factor this code better. We want to be able to substitute
different code for the experts themselves.
Dimensions cheat sheet:
O: outer batch dim
B: batch dim(s)
L: original sequence length
M: input depth
N: output depth
G: number of groups
S: group size
E: number of experts
C: expert capacity
Args:
inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
output_dim: a mtf.Dimension (for Transformer, this is input_dim)
hparams: model hyperparameters
train: a boolean
variable_dtype: a mtf.VariableDType
layout: optional - an input to mtf.convert_to_layout_rules
mesh_shape: optional - an input to mtf.convert_to_shape
nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
and the same dtype as inputs, consisting of ones (at nonpadding positions)
and zeros (at padding positions).
activation: a function.
num_microbatches: number of microbatches.
token_embeddings: a mtf.Tensor with shape
[batch_dim(s), length_dim, input_dim]. These are the word embeddings that
correspond to the inputs. They can optionally be used to make
routing decisions.
Returns:
outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
loss: a mtf scalar
Raises:
ValueError: on unrecognized hparams.moe_gating
"""
# pylint: disable=line-too-long
#
# The O (outer_batch) dimension can be used for expert replication, e.g.
# outer_batch=4 for placing 128 experts on 512 cores with 4 replicas of each
# expert.
#
# E.g. 16x16 basic example:
# moe_num_experts=512, num_groups=1024, batch=4096, length=256, d_model=1024
# ---
# Below, ` marks the dimension that is typically split along a mesh dimension.
#
# orig_inputs OB`LM Tensor
# Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
# v (reshaped)
# inputs OG`SM
# Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
#
# combine_tensor,
# dispatch_tensor OG`SEC
# Shape[outer_batch=1, batch=1024, group=1024, expert_unsplit=512, expert_capacity=4]
#
# (dispatched inputs)
# expert_inputs OEG`CM
# Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
# v (re-split via ReshapeOperation)
# OE`GCM
# Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
#
# (hidden representation)
# h OE`GCH
# Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, expert_hidden=8192]
#
# expert_output OE`GCM
# Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
# v (re-split via ReshapeOperation)
# OEG`CM
# Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
#
# (combined expert_output)
# output OG`SM
# Shape[outer_batch=1, batch=1024, group=1024, d_model=1024
# v (reshape)
# OB`LM
# Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
#
# pylint: enable=line-too-long
orig_inputs = inputs
hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)
# We "cheat" here and look at the mesh shape and layout. This is to ensure
# that the number of groups is a multiple of the mesh dimension
# over which those groups are split.
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
orig_inputs.shape.dims[-1])
# Hack: we assume that
# "outer_batch" == replication of experts
# mesh_dim_size can be derived from mesh_shape and orig_batch_dim
#
# We then require num_groups to be a multiple of mesh_dim_size.
if orig_inputs.shape.dims[0].name == "outer_batch":
outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
else:
outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
orig_inputs.shape.dims[0])
# Number of MoE inputs (total number of positions across batch_and_length_dims
# per replica).
n = 1
for d in batch_and_length_dims:
n *= d.size
n = n // outer_batch_dim.size
mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
orig_batch_dim)
num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
mesh_dim_size)
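# Rough worked example (using the 16x16 sizes from the comment above and
# assuming moe_group_size=1024): n = 4096 * 256 = 1,048,576 positions, which
# _split_into_groups divides into num_groups=1024 groups of group_size=1024,
# chosen so that num_groups is a multiple of mesh_dim_size.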
group_size_dim = mtf.Dimension("group", group_size)
num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
# OGSM Tensor
inputs = mtf.reshape(inputs, moe_input_dims)
# Token embeddings that can optionally be used in the router to determine
# where to send tokens.
if hparams.moe_word_embed_mode is not None:
token_embeddings = mtf.cast(
mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)
# Each group sends at most expert_capacity positions to each expert.
if train:
capacity_factor = hparams.moe_capacity_factor_train
else:
capacity_factor = hparams.moe_capacity_factor_eval
expert_capacity = min(
group_size_dim.size,
int((group_size_dim.size * capacity_factor) / experts_dim.size))
expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
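# e.g. with group_size=1024, 512 experts, and an assumed capacity_factor of
# 2.0, this gives min(1024, int(1024 * 2.0 / 512)) = 4, matching the
# expert_capacity=4 used in the shape walkthrough above.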
tf.logging.info("expert_capacity: %d" % expert_capacity)
expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
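# If a padding mask was provided, broadcast it to the full
# batch_and_length_dims shape and reshape it to [outer_batch, group,
# group_size] so the gating functions below can use it as per-position
# importance weights.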
if nonpadding is not None:
nonpadding = mtf.zeros(
inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
if hparams.moe_gating == "top_2":
# combine_tensor,
# dispatch_tensor OG`SEC Tensors
# (G is generally split along mesh dim)
dispatch_tensor, combine_tensor, loss = _top_2_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "top_n":
dispatch_tensor, combine_tensor, loss = _top_n_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "switch":
dispatch_tensor, combine_tensor, loss = _switch_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "ntlb":
dispatch_tensor, combine_tensor, loss = _ntlb_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "switch_max":
dispatch_tensor, combine_tensor, loss = _switch_max_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "expert_selection":
dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=experts_dim_unsplit,
group_size_dim=group_size_dim,
expert_capacity_dim=expert_capacity_dim,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
name="expert_selection_gating",
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
else:
raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
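# Dispatch: this einsum is the "gather-by-matmul" trick mentioned in the
# docstring. dispatch_tensor is (approximately) a one-hot mask over
# (expert, capacity slot), so the contraction copies each position's input
# vector into its assigned slot of its assigned expert(s); positions that
# overflow an expert's capacity are dropped.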
expert_inputs = mtf.einsum([inputs, dispatch_tensor],
mtf.Shape([
outer_batch_dim, experts_dim_unsplit,
num_groups_dim, expert_capacity_dim, input_dim
]))
# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
d_model_split_dim
]))
# Split over batch -> split over experts
expert_inputs = mtf.reshape(
expert_inputs,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
input_dim
]))
# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
reduced_dims=expert_inputs.shape.dims[-1:],
new_dims=[hidden_dim],
expert_dims=[experts_dim],
activation_functions=activation, use_bias=False,
variable_dtype=variable_dtype, name="wi")
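# "wi" above is the first dense layer of each expert (d_model -> expert_hidden).
# The second layer ("wo", or "q_wo"/"k_wo" below) projects back to output_dim
# inside _compute_output.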
if hparams.moe_dropout_rate != 0.0:
h = mtf.dropout(h, is_training=train,
keep_prob=1.0 - hparams.moe_dropout_rate)
def _compute_output(hidden, layer_name):
"""Compute the output of the attention layer from the hidden vector."""
expert_output = mtf.layers.dense(
hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
name=layer_name)
# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
d_model_split_dim = mtf.Dimension(
"d_model_split", expert_output.shape[-1].size)
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, d_model_split_dim
]))
# Split over experts -> split over batch
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim,
experts_dim_unsplit,
num_groups_dim,
expert_capacity_dim,
output_dim,
]))
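# Combine: weighted sum over (expert, capacity slot) using the gate values
# stored in combine_tensor, producing each position's output back in its
# original group/position order.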
moe_output_dims = moe_input_dims[:-1] + [output_dim]
output = mtf.einsum([expert_output, combine_tensor],
mtf.Shape(moe_output_dims))
output = mtf.reshape(output, batch_and_length_dims + [output_dim])
return output
if hparams.moe_use_experts_attention:
# We share k_h and v_h with no degradation in performance
q_h, k_h = h, h
outputs = []
q = _compute_output(q_h, layer_name="q_wo")
k = _compute_output(k_h, layer_name="k_wo")
outputs.append(q)
outputs.append(k)
return outputs, loss * hparams.moe_loss_coef
else:
output = _compute_output(h, layer_name="wo")
return output, loss * hparams.moe_loss_coef
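# A minimal usage sketch (names such as x, model_dim and cross_entropy_loss
# are placeholders, not defined in this module). The returned aux_loss is
# already scaled by hparams.moe_loss_coef:
#
#   moe_output, aux_loss = transformer_moe_layer_v1(
#       x, output_dim=model_dim, hparams=hparams, train=True,
#       variable_dtype=variable_dtype, layout=layout, mesh_shape=mesh_shape)
#   total_loss = cross_entropy_loss + aux_loss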