in mesh_tensorflow/transformer/moe.py
def transformer_moe_layer_v2(
inputs, output_dim, hparams, train, variable_dtype,
layout=None, mesh_shape=None, nonpadding=None, num_microbatches=None):
"""2-level mixture of experts.
Adapted from the paper https://arxiv.org/abs/1701.06538
Note: until the algorithm and interface solidify, we pass in a hyperparameters
dictionary in order not to complicate the interface in mtf_transformer.py.
Once this code moves out of "research", we should pass the hyperparameters
separately.
Hyperparameters used:
hparams.moe_num_experts: number of experts
hparams.moe_hidden_size: size of hidden layer in each expert
hparams.moe_group_size: size of each "group" for gating purposes
hparams.moe_capacity_factor_train: a float
hparams.moe_capacity_factor_eval: a float
hparams.moe_capacity_factor_second_level: a float
hparams.moe_min_expert_capacity: an integer (lower bound on each capacity)
hparams.moe_loss_coef: a float (multiplier on the auxiliary balancing loss)
hparams.moe_gating: a string
+ all hyperparameters used by _top_2_gating()
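A minimal illustrative configuration (values are examples only, not
recommendations from this file):
moe_num_experts = [8, 8]   # 8 first-level experts x 8 second-level experts
moe_hidden_size = 4096
moe_group_size = 1024
moe_capacity_factor_train = 1.25
moe_capacity_factor_eval = 2.0
moe_capacity_factor_second_level = 1.0
moe_min_expert_capacity = 4
moe_gating = "top_2"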
There is one set of gating parameters for the first level, and a separate
set of gating parameters per first-level expert for the second level.
The number of parameters in the gating networks is:
(input_dim.size * hparams.moe_num_experts[0]
+ input_dim.size * hparams.moe_num_experts[0] * hparams.moe_num_experts[1])
The number of parameters in the experts themselves is:
(hparams.moe_num_experts[0] * hparams.moe_num_experts[1]
* (input_dim.size + output_dim.size)
* hparams.moe_hidden_size)
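For example (illustrative sizes, not recommendations): with
input_dim.size = output_dim.size = 1024, hparams.moe_num_experts = [8, 8]
and hparams.moe_hidden_size = 4096, the gating networks have
1024 * 8 + 1024 * 8 * 8 = 73,728 parameters and the experts have
8 * 8 * (1024 + 1024) * 4096 = 536,870,912 parameters.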
The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
of the representations of all positions in a batch of sequences.
Each position of each sequence is dispatched to up to two first-level
experts, and each dispatched copy to up to two second-level experts (fewer
where expert capacity runs out). The expert choices and the combination
weights are determined by learned gating functions.
This function returns a small auxiliary loss that should be added to the
training loss of the model. This loss helps to balance expert usage.
Without the loss, it is very likely that a few experts will be trained and
the rest will starve.
Several hacks are necessary to get around current TPU limitations:
- To ensure static shapes, we enforce (by truncation/padding)
that each sequence sends the same number of elements to each expert.
It would make more sense to enforce this equality over the entire batch,
but due to our hacked-up gather-by-matmul implementation, we need to divide
the batch into "groups". For each group, the same number of elements
are sent to each expert.
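For example (illustrative numbers): a batch of 8 sequences of length 1024
yields 8192 positions; with hparams.moe_group_size = 1024 these are divided
into g = 8 groups of s = 1024 positions, and within each group every expert
receives the same capacity-limited number of positions.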
TODO(noam): Factor this code better. We want to be able to substitute
different code for the experts themselves.
Dimensions cheat sheet:
a, b: batch size
l: original sequence length
m: input depth
n: output depth
g, h: number of groups
s, t: group size
x, y: number of experts
c, d: expert capacity
input: [a0, b1, l, m]
reshaped: [a0, g1, s, m]
dispatch_tensor_x: [a0, g1, s, x, c]
expert_input: [a0, g1, x, c, m]
alltoall: [a0, g, x1, c, m]
transpose: [x1, a0, g, c, m]
reshape: [x1, h0, t, m]
assignment2: [x1, h0, t, y, d]
expert_input2: [x1, h0, y, d, m]
alltoall: [x1, h, y0, d, m]
...
reverse of that
gating params 0: [m, x]
gating params 1: [x1, m, y]
expert params:
[x1, y0, m, hidden]
[x1, y0, hidden, n]
Args:
inputs: a mtf.Tensor with shape [a, b, l, m]
output_dim: a mtf.Dimension (for Transformer, this is input_dim)
hparams: model hyperparameters
train: a boolean
variable_dtype: a mtf.VariableDType
layout: optional - an input to mtf.convert_to_layout_rules
mesh_shape: optional - an input to mtf.convert_to_shape
nonpadding: an optional mtf.Tensor with shape [a, b, l]
and the same dtype as inputs, containing ones at nonpadding
positions and zeros at padding positions.
num_microbatches: number of microbatches.
Returns:
outputs: a Tensor with shape [a, b, l, n]
loss: a mtf scalar
Raises:
ValueError: on unrecognized hparams.moe_gating
"""
if nonpadding is not None:
nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1],
dtype=inputs.dtype) + nonpadding
insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
if insert_outer_batch_dim:
inputs = mtf.reshape(
inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)
assert len(hparams.moe_num_experts) == 2
a0, b1, l, m = inputs.shape.dims
hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
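# By the naming convention in the cheat sheet above, a trailing 0 or 1 on a
# dimension variable marks a dimension the layout may split over mesh axis 0
# or 1, while the "_unsplit" names are not split across the mesh. For
# example, hparams.moe_num_experts = [8, 8] gives x1.size = y0.size = 8,
# i.e. 8 * 8 = 64 expert pairs in total.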
n = output_dim
# We "cheat" here and look at the mesh shape and layout. This is to ensure
# that the number of groups (g.size) is a multiple of the mesh dimension
# over which those groups are split.
num_groups, group_size = _split_into_groups(
b1.size * l.size, hparams.moe_group_size,
mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
g1 = mtf.Dimension(b1.name, num_groups)
g = mtf.Dimension(b1.name + "_unsplit", g1.size)
s = mtf.Dimension("group_size_x", group_size)
# Each group sends at most expert_capacity positions to each expert.
# A static expert_capacity dimension is needed to give the expert batches
# static shapes.
if train:
capacity_factor = hparams.moe_capacity_factor_train
else:
capacity_factor = hparams.moe_capacity_factor_eval
expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
c = mtf.Dimension("expert_capacity_x", expert_capacity)
# We "cheat" here and look at the mesh shape and layout. This is to ensure
# that the number of groups (h.size) is a multiple of the mesh dimension
# over which those groups are split.
num_groups, group_size = _split_into_groups(
a0.size * g.size * c.size,
hparams.moe_group_size,
mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
t = mtf.Dimension("group_size_y", group_size)
h0 = mtf.Dimension(a0.name, num_groups)
h = mtf.Dimension(a0.name + "_unsplit", h0.size)
expert_capacity = min(
t.size,
int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
d = mtf.Dimension("expert_capacity_y", expert_capacity)
# First level of expert routing
# Reshape the inner batch size to a multiple of group_dim g1 and
# group_size_dim s.
inputs = mtf.reshape(inputs, [a0, g1, s, m])
if nonpadding is not None:
nonpadding = mtf.reshape(nonpadding, [a0, g1, s])
# Get the assignments for the first level.
# dispatch_tensor_x has shape [a0, g1, s, x, c]
if hparams.moe_gating == "top_2":
dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
inputs=inputs,
outer_expert_dims=None,
experts_dim=x,
expert_capacity_dim=c,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
name="outer_gating",
importance=nonpadding,
num_microbatches=num_microbatches)
else:
raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
# Now create expert_inputs based on the assignments.
# put num_experts dimension first to make split easier in alltoall
expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x], [x, a0, g1, c, m])
# we construct an "importance" Tensor for the inputs to the second-level
# gating. The importance of an input is 1.0 if it represents the
# first-choice expert-group and 0.5 if it represents the second-choice expert
# group. This is used by the second-level gating.
importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
importance = 0.5 * (
mtf.to_float(mtf.greater(importance, 0.5)) +
mtf.to_float(mtf.greater(importance, 0.0)))
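# Example: a slot whose first-level combine weight is 0.7 (a first choice)
# gets importance 0.5 * (1 + 1) = 1.0; a weight of 0.3 (a second choice)
# gets 0.5 * (0 + 1) = 0.5; an unfilled slot (weight 0) gets 0.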
# First level, all to all. Here we change the split dimension from g1 to x1.
expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape(
[x1, a0, g, c, m]))
importance = mtf.reshape(importance, [x1, a0, g, c])
# Second level of expert routing
# Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
# and group_size_dim t.
inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
importance = mtf.reshape(importance, [x1, h0, t])
# Get the assignments for the second level.
# dispatch_tensor_y has shape [x1, h0, t, y, d]
if hparams.moe_gating == "top_2":
dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
inputs=inputs_y,
outer_expert_dims=[x1],
experts_dim=y,
expert_capacity_dim=d,
hparams=hparams,
train=train,
variable_dtype=variable_dtype,
importance=importance,
name="inner_gating",
num_microbatches=num_microbatches)
else:
raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
# Now create expert_inputs based on the assignments.
# put num_experts dimension first to make split easier in alltoall
expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y], [y, x1, h0, d, m])
# Second level, all to all. Here we change the split dimension from h0 to y0.
expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
[y0, x1, h, d, m]))
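# Each of the x * y expert pairs applies its own two-layer feedforward
# network, relu(input @ wi) @ wo, with weights of shape
# [y0, x1, m, hidden_dim] and [y0, x1, hidden_dim, n].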
hidden_output = mtf.layers.dense(
expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
reduced_dims=expert_inputs_y.shape.dims[-1:],
activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype,
name="wi")
expert_output = mtf.layers.dense(
hidden_output, output_dim, expert_dims=[y0, x1],
reduced_dims=hidden_output.shape.dims[-1:],
use_bias=False, variable_dtype=variable_dtype,
name="wo")
# NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
# expert_output has shape [y0, x1, h, d, n]
# alltoall
expert_output = mtf.reshape(expert_output, mtf.Shape(
[y, x1, h0, d, n]))
# combine results from inner level
output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])
# Reshape the combined tensor from inner level to now contain outer_batch_dim
# a0 and group_dim g
output = mtf.reshape(output_y, [x1, a0, g, c, n])
# alltoall from expert_dim x to group_dim g1
expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))
# combine results from outer level
output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])
# Reshape the combined tensor to now contain inner_batch_dim
# b1 and the original sequence length
output = mtf.reshape(output_x, [a0, b1, l, n])
if insert_outer_batch_dim:
output = mtf.reshape(output, [b1, l, n])
return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
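# A minimal usage sketch (names such as `activations`, `model_dim`,
# `task_loss`, and `variable_dtype` are illustrative and assumed to be
# constructed elsewhere):
#
#   outputs, aux_loss = transformer_moe_layer_v2(
#       inputs=activations,        # mtf.Tensor of shape [batch, length, m]
#       output_dim=model_dim,
#       hparams=hparams,
#       train=True,
#       variable_dtype=variable_dtype,
#       layout=hparams.layout,
#       mesh_shape=hparams.mesh_shape)
#   loss = task_loss + aux_loss    # include the auxiliary balancing loss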