in mesh_tensorflow/transformer/moe.py [0:0]
def _expert_selection_gating(
    inputs, outer_expert_dims, experts_dim, group_size_dim,
    expert_capacity_dim, hparams, train, variable_dtype, importance=None,
    name="expert_selection_gating", num_microbatches=None,
    normalize_by_num_experts_routed=True, token_embeddings=None):
  """Compute gating where each expert chooses what tokens it wants.

  Instead of each token picking its top experts, the router softmax is taken
  over the group dimension and every expert independently selects the
  `expert_capacity_dim.size` tokens of the group with the highest probability
  (top-k over `group_size_dim`). A token may therefore be picked by several
  experts or by none.

  Args:
    inputs: a Tensor fed to the router. It is cast to float32 internally. The
      shape comments below assume it is [outer_batch, batch, group, d_model]
      (NOTE(review): confirm against callers).
    outer_expert_dims: passed as `expert_dims` to the router dense layer.
    experts_dim: Dimension holding the number of experts.
    group_size_dim: Dimension of tokens in a group; the router softmax and the
      per-expert top-k are taken over it.
    expert_capacity_dim: Dimension whose size is the number of tokens each
      expert selects (the k of top-k).
    hparams: hyperparameters. Reads moe_switch_policy_train,
      moe_switch_policy_eval, moe_switch_dropout, moe_switch_jitter,
      moe_word_embed_mode and moe_z_loss.
    train: bool. Selects the train/eval switch policy and enables the router
      z-loss and the training summaries.
    variable_dtype: dtype used for the router variables.
    importance: optional Tensor over [outer_batch, batch, group].
      - Positions equal to 0.0 get -1e9 added to their router logits.
      - Only positions equal to 1.0 contribute to the load-balancing loss.
      - Its sum over the group is used as the "true" group size.
      If None, nothing is masked and the full group size is used.
    name: name of the router dense layer; also the prefix of the entropy and
      z_loss summaries.
    num_microbatches: if > 1, the load-balancing loss is divided by it.
    normalize_by_num_experts_routed: if True, each token's combine weights are
      divided by the number of experts that selected it (at least 1).
    token_embeddings: forwarded to `_add_token_emb_to_gate_inputs` when
      `hparams.moe_word_embed_mode` is not None.

  Returns:
    dispatch_tensor: 0/1 Tensor of shape
      [outer_batch, batch, experts, expert_capacity, group], in inputs.dtype.
    combine_tensor: Tensor of the same shape; the dispatch mask weighted by
      the router probabilities (optionally normalized), in inputs.dtype.
    loss: scalar auxiliary loss (load balancing, plus the z-loss when
      enabled), in inputs.dtype.
  """
  # Select the randomization policy.
  if train:
    policy = hparams.moe_switch_policy_train
  else:
    policy = hparams.moe_switch_policy_eval

  # The internals of this function run in float32; otherwise instabilities
  # can occur.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations for exploration.
  if policy == "input_dropout":
    gate_inputs = mtf.dropout(gate_inputs, is_training=train,
                              keep_prob=1.0 - hparams.moe_switch_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_switch_jitter)
  if hparams.moe_word_embed_mode is not None:
    gate_inputs = _add_token_emb_to_gate_inputs(
        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

  # Compute expert logits for each token.
  # gate_logits shape: [outer_batch, batch, group, expert_unsplit]
  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)

  # The softmax below is normalized over all tokens in the group, so tokens
  # with zero importance get a -1e9 logit to keep them out of the selection.
  if importance is not None:
    gate_logits += mtf.cast(
        mtf.equal(importance, 0.0), dtype=gate_logits.dtype) * -1e9
  raw_gates = mtf.softmax(gate_logits, reduced_dim=group_size_dim)

  # Each expert picks its top-`expert_capacity` tokens from the group.
  # expert_gate_probs shape:
  #   [outer_batch, batch, expert_unsplit, expert_capacity]
  # expert_gate_indices shape:
  #   [outer_batch, batch, expert_unsplit, expert_capacity]
  expert_gate_probs, expert_gate_indices = mtf.top_k(
      raw_gates, reduced_dim=group_size_dim, k_dim=expert_capacity_dim)

  # dispatch_tensor shape:
  #   [outer_batch, batch, expert_unsplit, expert_capacity, group]
  dispatch_tensor = mtf.one_hot(
      expert_gate_indices, group_size_dim, dtype=raw_gates.dtype)
  # combine_tensor shape:
  #   [outer_batch, batch, expert_unsplit, expert_capacity, group]
  combine_tensor = dispatch_tensor * expert_gate_probs

  # A token's outputs are aggregated across every expert that picked it and
  # are otherwise not normalized, so optionally divide by the number of
  # experts each token was routed to. The maximum() avoids dividing by zero
  # for tokens no expert picked.
  if normalize_by_num_experts_routed:
    num_experts_routed = mtf.reduce_sum(
        dispatch_tensor,
        output_shape=(dispatch_tensor.shape[:2] + [group_size_dim]))
    combine_tensor /= mtf.maximum(num_experts_routed, 1.0)

  ################### Compute the load balancing loss ###################
  # Push `aggregated_group_probs` (shape [outer_batch, batch, group]) toward
  # uniform. raw_gates sums to 1 over the group for every expert, so the
  # mean over experts also sums to (approximately) 1 over the group.
  aggregated_group_probs = mtf.reduce_mean(raw_gates, reduced_dim=experts_dim)

  # Scale the loss by the group size to keep it constant across different
  # group sizes. true_group_size is the number of tokens per group that are
  # not masked out; with no importance mask, every token counts.
  if importance is not None:
    # importance shape: [outer_batch, batch, group]
    aggregated_group_probs *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    true_group_size = mtf.cast(
        mtf.reduce_sum(importance, reduced_dim=group_size_dim),
        dtype=raw_gates.dtype)
  else:
    true_group_size = float(group_size_dim.size)

  # Equals sum_g(p_g^2) * true_group_size per group, averaged over the batch.
  # That is >= 1, with equality when the probability mass is spread uniformly
  # over the unmasked tokens (before the num_microbatches division below).
  loss = (mtf.reduce_mean(
      aggregated_group_probs * aggregated_group_probs * true_group_size) *
          float(group_size_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for the router.
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  ################### Logging ###################
  if train:
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=group_size_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Log, for each token position in the group, how many experts it is sent
    # to. NOTE(review): the count is summed over outer_batch and batch, then
    # multiplied by experts * capacity; confirm this scaling is intended.
    num_experts_sent_per_token = (
        mtf.reduce_sum(dispatch_tensor, output_shape=[group_size_dim]) *
        float(experts_dim.size * expert_capacity_dim.size))
    split_fractions = mtf.split(
        num_experts_sent_per_token,
        split_dim=group_size_dim,
        num_or_size_splits=group_size_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("group_token/" + fraction.name.replace(":", "/"),
                         mtf.reduce_sum(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  #################### Match the inputs dtype ###################
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(dispatch_tensor, tf.bool), combine_tensor.dtype)
  return dispatch_tensor, combine_tensor, loss