in mesh_tensorflow/transformer/moe.py [0:0]
def _top_n_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="top_n_gating",
    num_microbatches=None, token_embeddings=None):
  """Compute a generalization of top-2 gating for mixture-of-experts.

  Each token is routed to up to `hparams.moe_top_n_num_experts_per_token`
  experts. The top-1 expert is always used (subject to expert capacity). For
  n > 1, whether the n-th expert is used is decided by the second-place
  policy/threshold hyperparameters.

  Hyperparameters used:
    hparams.moe_top_n_num_experts_per_token: an int, must be >= 1
    hparams.moe_use_second_place_loss: a boolean (requires
      moe_top_n_num_experts_per_token >= 2)
    hparams.moe_second_policy_train: a string ("all", "none", "threshold" or
      "random")
    hparams.moe_second_policy_eval: a string (same choices as above)
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float
    hparams.moe_word_embed_mode: None, or a mode forwarded to
      _add_token_emb_to_gate_inputs
    hparams.moe_z_loss: None, or the weight of the router z-loss (only
      applied when train is True)

  Tensor shapes are largely the same as in top_2 gating, so see that docstring
  for more details.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions. This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim].
      Tokens with importance <= 0 are not routed; only tokens with
      importance == 1 contribute to the load-balancing proxy.
    name: an optional string
    num_microbatches: number of microbatches. When > 1 the load-balancing
      loss is divided by it.
    token_embeddings: an optional tensor with shape
      [<batch_dims>, group_size_dim, input_dim] that is the input
      word embeddings.

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters: fewer than one expert per token,
      the second-place loss requested with fewer than two experts per token,
      or (when more than one expert per token is used) an unknown
      second-place policy.
  """
  num_experts_per_token = hparams.moe_top_n_num_experts_per_token
  if num_experts_per_token < 1:
    raise ValueError(
        "moe_top_n_num_experts_per_token must be >= 1, got %s"
        % num_experts_per_token)
  if hparams.moe_use_second_place_loss and num_experts_per_token < 2:
    raise ValueError(
        "moe_use_second_place_loss requires "
        "moe_top_n_num_experts_per_token >= 2, got %s"
        % num_experts_per_token)

  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]
  # The internals of this function run in float32.
  # bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)
  if hparams.moe_word_embed_mode is not None:
    gate_inputs = _add_token_emb_to_gate_inputs(
        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)
  gate_logits = mtf.layers.dense(
      gate_inputs, experts_dim, use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, experts_dim)
  expert_capacity_f = float(expert_capacity_dim.size)

  # Used for the auxiliary load-balancing loss.
  density_1_proxy = raw_gates
  if importance is not None:
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    # Loop-invariant: tokens with importance <= 0 are never routed.
    routable = mtf.to_float(mtf.greater(importance, 0.0))

  # Loop to pick the top-n experts for each token, together with their masks.
  gates = []
  masks = []
  indexes = []
  # Raw gates with the experts chosen so far zeroed out.
  gates_without_top_n = raw_gates
  # Raw gates with only the top-1 expert removed; used for second-place loss.
  gates_without_top_1 = None
  for n in range(num_experts_per_token):
    # [batch, group]
    gate_n, index_n = mtf.top_1(gates_without_top_n, experts_dim)
    # [batch, group, experts]
    mask_n = mtf.one_hot(index_n, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
      mask_n *= routable
      gate_n *= routable
    # Rebind explicitly so the snapshot below is never aliased/mutated.
    gates_without_top_n = gates_without_top_n * (1.0 - mask_n)
    if n == 0:
      # Snapshot right after the top-1 expert has been removed.
      gates_without_top_1 = gates_without_top_n
    gates.append(gate_n)
    masks.append(mask_n)
    indexes.append(index_n)

  if len(gates) > 1:
    # Gate probabilities are renormalized over the selected top-n experts.
    denom = mtf.add_n(gates) + 1e-9
    gates = [gate / denom for gate in gates]

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert.
  mask_1 = masks[0]  # Mask for the top-1 expert.
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))
  # TODO(barretzoph): Add in options for aux losses for n > 2.
  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert. Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert
    # as its second choice (masks[1], before the second-place policy).
    density_2 = mtf.reduce_mean(masks[1], reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for router.
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # Depending on the policy in the hparams, we may drop out some of the
  # n-th place (n > 1) experts.
  def _update_mask_based_on_gate_value(gate_n, mask_n):
    """Update the mask based on the policy and the threshold for n > 1.

    Args:
      gate_n: normalized router probability for the n-th highest expert.
      mask_n: one-hot float tensor that keeps track of the n-th expert to
        send each token to. This also masks away tokens that will not be
        routed.

    Returns:
      An altered mask_n that masks out any n-th choice that doesn't satisfy
      the second-place policy and threshold.

    Raises:
      ValueError: if the configured policy is unknown.
    """
    if train:
      policy = hparams.moe_second_policy_train
      threshold = hparams.moe_second_threshold_train
    else:
      policy = hparams.moe_second_policy_eval
      threshold = hparams.moe_second_threshold_eval
    if policy == "all":
      # Use n-th place experts for all examples.
      pass
    elif policy == "none":
      # Never use n-th place experts.
      mask_n = mtf.zeros_like(mask_n)
    elif policy == "threshold":
      # Use n-th place experts if gate_n > threshold.
      mask_n *= mtf.to_float(mtf.greater(gate_n, threshold))
    elif policy == "random":
      # Use n-th place experts with probability min(1.0, gate_n / threshold).
      mask_n *= mtf.to_float(
          mtf.less(mtf.random_uniform(gate_n.mesh, gate_n.shape),
                   gate_n / max(threshold, 1e-9)))
    else:
      raise ValueError("Unknown policy %s" % policy)
    return mask_n

  # Now update masks for n > 1 to reflect how these additional choices should
  # be routed according to the policy. The top-1 choice is always routed.
  for i in range(1, len(masks)):
    masks[i] = _update_mask_based_on_gate_value(gates[i], masks[i])

  def _compute_top_n_mask(gate_n, mask_n, index_n, prev_mask_count):
    """Assign expert-capacity slots to the n-th choice of every token.

    Args:
      gate_n: [batch, group] normalized gate value of the n-th choice.
      mask_n: [batch, group, experts] mask of tokens routed to their n-th
        choice.
      index_n: [batch, group] index of the n-th chosen expert.
      prev_mask_count: running count of tokens already placed in each expert
        by earlier choices (0.0 before the first choice).

    Returns:
      A tuple (updated prev_mask_count, partial combine tensor with shape
      [batch, group, experts, expert_capacity]).
    """
    # This is the position within the expert's mini-batch for this sequence;
    # slots already taken by earlier choices are skipped via prev_mask_count.
    position_in_expert_n = (
        mtf.cumsum(mask_n, group_size_dim, exclusive=True) + prev_mask_count)
    # Mask out tokens that should not be routed.
    position_in_expert_n *= mask_n
    # Remove the elements that don't fit. [batch, group, experts]
    mask_n *= mtf.to_float(mtf.less(position_in_expert_n, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert.
    mask_n_count = mtf.reduce_sum(mask_n, reduced_dim=group_size_dim)
    # Keep running sum of total tokens sent to each expert.
    prev_mask_count += mask_n_count
    # [batch, group] - mostly ones, but zeros where something didn't fit.
    mask_n_flat = mtf.reduce_sum(mask_n, reduced_dim=experts_dim)
    # Weight assigned to the n-th expert (zero for dropped tokens).
    # [batch, group]
    gate_n *= mask_n_flat
    # [batch, group]
    position_in_expert_n = mtf.reduce_sum(
        position_in_expert_n, reduced_dim=experts_dim)
    partial_combine_tensor = (
        gate_n
        * mtf.one_hot(index_n, experts_dim)
        * mtf.one_hot(mtf.to_int32(position_in_expert_n), expert_capacity_dim))
    return prev_mask_count, partial_combine_tensor

  # [batch, experts]
  # How many examples in this group go to each expert. This starts at zero.
  prev_mask_count = 0.0
  partial_combine_tensors = []
  for gate_n, mask_n, index_n in zip(gates, masks, indexes):
    prev_mask_count, partial_combine_tensor = _compute_top_n_mask(
        gate_n, mask_n, index_n, prev_mask_count)
    partial_combine_tensors.append(partial_combine_tensor)

  combine_tensor = mtf.add_n(partial_combine_tensors)
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # Dispatch is 1 wherever the combine weight is non-zero.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)
  return dispatch_tensor, combine_tensor, loss