in mesh_tensorflow/transformer/moe.py [0:0]
def _top_2_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="top_2_gating",
    num_microbatches=None, token_embeddings=None):
  """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py .  Once this code moves out of "research", we should pass
  the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_word_embed_mode: if not None, `token_embeddings` are combined
      into the gate inputs via `_add_token_emb_to_gate_inputs`
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_z_loss: if not None (and `train` is true), the coefficient
      applied to the router z-loss that is added to the returned loss
    hparams.moe_second_policy_train: a string, one of "all", "none",
      "threshold", "random"
    hparams.moe_second_policy_eval: a string, same choices as above
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  The returned dispatch_tensor is a tensor used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned combine_tensor is
  used to map (via einsum) from the expert outputs to the outputs.  Both the
  dispatch and combine tensors are mostly zeros.  The shapes of the tensors
  are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  When `train` is true a number of scalar summaries are emitted (gate entropy,
  normalized gate means, fraction of tokens routed, per-expert load).  If
  `importance` is None, the "fraction routed" summaries are normalized by the
  total number of tokens.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string
    num_microbatches: number of microbatches.
    token_embeddings: an optional tensor with shape
      [<batch_dims>, group_size_dim, input_dim] that is the input
      word embeddings.

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  if hparams.moe_word_embed_mode is not None:
    gate_inputs = _add_token_emb_to_gate_inputs(
        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

  gate_logits = mtf.layers.dense(
      gate_inputs, experts_dim, use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, experts_dim)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position. shape=[batch, group]
  gate_1, index_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    # Only inputs with importance exactly 1.0 get a first-choice expert.
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  gate_2, index_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
  if importance is not None:
    # Any input with non-zero importance may get a second-choice expert.
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  # Renormalize the two selected gates; the epsilon guards against 0 / 0.
  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for router.
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  if train:
    policy = hparams.moe_second_policy_train
    threshold = hparams.moe_second_threshold_train
  else:
    policy = hparams.moe_second_policy_eval
    threshold = hparams.moe_second_threshold_eval
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts for all examples.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  # Second choices are placed after all first choices of the same expert.
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  if train:
    # Gate entropy.
    if importance is not None:
      raw_gates *= mtf.to_float(mtf.greater(importance, 0.0))
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Mean top-1 and top-2 normalized gate probabilities.
    if importance is not None:
      gate_2 *= mtf.to_float(mtf.greater(importance, 0.0))
    mtf.scalar_summary("top1_gate_normalized", mtf.reduce_mean(gate_1))
    mtf.scalar_summary("top2_gate_normalized", mtf.reduce_mean(gate_2))
    top1_routed = mtf.reduce_sum(mask_1_flat)
    top2_routed = mtf.reduce_sum(mask_2_flat)

    # Denominator for the "fraction routed" summaries.  `importance` is
    # optional, so it must not be cast/reduced unconditionally: without it
    # no input is masked out above, and we normalize by the token count.
    if importance is None:
      importance_total = mtf.reduce_sum(mtf.ones_like(mask_1_flat))
    else:
      importance_total = mtf.reduce_sum(
          mtf.cast(importance, dtype=top1_routed.dtype))

    # What fraction of the top-1 and top-2 tokens are being routed to any
    # expert.
    mtf.scalar_summary("top1_fraction_routed",
                       top1_routed / importance_total)
    mtf.scalar_summary("top2_fraction_routed",
                       top2_routed / importance_total)
    # One or zero if that token got routed anywhere.
    total_routed = mtf.reduce_sum(mtf.minimum(
        mask_1_flat + mask_2_flat, mtf.ones_like(top1_routed)))
    mtf.scalar_summary("all_fraction_routed",
                       total_routed / importance_total)
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

    # Log what fraction of tokens are going to each expert.
    def _log_per_expert_fraction(mask, summary_prefix):
      """Emit one scalar summary per expert with its share of `mask`."""
      # mask: [batch, group, experts]
      tokens_per_expert = mtf.reduce_sum(mask, output_shape=[experts_dim])
      routed_total = mtf.reduce_sum(tokens_per_expert)
      expert_fraction = mtf.to_float(tokens_per_expert / routed_total)
      split_fractions = mtf.split(
          expert_fraction,
          split_dim=experts_dim,
          num_or_size_splits=experts_dim.size)
      for fraction in split_fractions:
        mtf.scalar_summary(
            summary_prefix + "_experts/" + fraction.name.replace(":", "/"),
            mtf.reduce_mean(fraction))

    _log_per_expert_fraction(mask_1, "top1")
    _log_per_expert_fraction(mask_2, "top2")
    _log_per_expert_fraction(mask_1 + mask_2, "all")

  # [batch, group, experts, expert_capacity]
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  # The dispatch tensor is the 0/1 indicator of the non-zero combine weights.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss