in mesh_tensorflow/transformer/moe.py [0:0]
def _ntlb_gating(inputs,
                 outer_expert_dims,
                 experts_dim,
                 expert_capacity_dim,
                 hparams,
                 train,
                 variable_dtype,
                 importance=None,
                 name="ntlb_gating",
                 num_microbatches=None,
                 token_embeddings=None):
  """Build Switch-style gating with no-token-left-behind (NTLB) routing.

  A bias-free dense layer produces router logits over `experts_dim`; their
  softmax is reduced to the top `hparams.moe_ntlb_top_k` experts per token.
  Tokens are then assigned in rounds: round i offers every still-unrouted
  token to its i-th choice and accepts it only while that expert's running
  token count stays within `expert_capacity_dim.size`.

  Args:
    inputs: mtf.Tensor to route. Dimensions 0, 1 and -2 are read as the
      outer-batch, batch and group-size dimensions (presumably the layout is
      [outer_batch, batch, group_size, model] -- confirm against the caller).
    outer_expert_dims: forwarded as `expert_dims` to the router dense layer.
    experts_dim: mtf.Dimension enumerating the experts.
    expert_capacity_dim: mtf.Dimension whose size is the per-expert capacity.
    hparams: object providing `moe_switch_policy_train`,
      `moe_switch_policy_eval`, `moe_switch_jitter`, `moe_word_embed_mode`,
      `moe_use_second_place_expert_prob`, `moe_ntlb_top_k` and `moe_z_loss`.
    train: truthy when building the training graph. Enables input jitter,
      stochastic non-top-expert selection, the router z-loss and summaries.
    variable_dtype: forwarded to the router dense layer.
    importance: optional tensor. When given, the expert mask, the expert
      gates and the mean router probability are multiplied by an indicator
      of `importance == 1.0`.
    name: name of the router dense layer and prefix of the z_loss/entropy
      summaries.
    num_microbatches: if greater than 1, the load-balancing loss is divided
      by this value; it is also passed to `_router_z_loss`.
    token_embeddings: passed to `_add_token_emb_to_gate_inputs` when
      `hparams.moe_word_embed_mode` is not None.

  Returns:
    A tuple `(dispatch_tensor, combine_tensor, loss)`. `combine_tensor` and
    `loss` are cast to `inputs.dtype`; `dispatch_tensor` is `combine_tensor`
    cast to bool and back to `combine_tensor.dtype`.
  """
  # Choose the routing policy that belongs to the current mode.
  routing_policy = (
      hparams.moe_switch_policy_train
      if train else hparams.moe_switch_policy_eval)

  # Router internals run in float32; bfloat16 seems to reduce quality.
  router_inputs = mtf.to_float(inputs)

  # Optional perturbation of the router inputs (training only).
  if train and routing_policy == "input_jitter":
    router_inputs = mtf.layers.multiplicative_jitter(
        router_inputs, hparams.moe_switch_jitter)

  if hparams.moe_word_embed_mode is not None:
    router_inputs = _add_token_emb_to_gate_inputs(
        router_inputs, token_embeddings, hparams.moe_word_embed_mode)

  router_logits = mtf.layers.dense(
      router_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)

  if hparams.moe_use_second_place_expert_prob is not None and train:
    router_logits = _stochastically_use_non_top_expert(
        router_logits, experts_dim, hparams)

  # Softmax over experts, again forced to float32.
  router_probs = mtf.to_float(
      mtf.softmax(router_logits, reduced_dim=experts_dim))

  # Per token: the k largest router probabilities and the experts they
  # belong to.
  k_dim = mtf.Dimension("k", hparams.moe_ntlb_top_k)
  topk_gate, topk_index = mtf.top_k(
      router_probs, reduced_dim=experts_dim, k_dim=k_dim)
  topk_mask = mtf.one_hot(topk_index, experts_dim)

  # ---- Load-balancing loss ----
  input_dims = inputs.shape
  outer_batch_dim, batch_dim = input_dims[0], input_dims[1]
  group_size_dim = input_dims[-2]

  # Both densities are taken before any importance masking is applied.
  mask_density = mtf.reduce_mean(topk_mask, reduced_dim=group_size_dim)
  prob_density = mtf.reduce_mean(router_probs, reduced_dim=group_size_dim)

  if importance is not None:

    def _importance_is_one():
      # Built afresh on every call so that the created ops match one-to-one
      # with writing the expression out three times.
      return mtf.cast(mtf.equal(importance, 1.0), dtype=router_probs.dtype)

    topk_mask *= _importance_is_one()
    topk_gate *= _importance_is_one()
    prob_density *= _importance_is_one()

  loss = (
      mtf.reduce_mean(prob_density * mask_density) *
      float(experts_dim.size * experts_dim.size))

  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # ---- Router z-loss (training only) ----
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(
        router_logits, experts_dim, num_microbatches, importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # ---- Summaries (training only) ----
  if train:
    # Mean entropy of the router distribution; 1e-9 guards log(0).
    token_entropy = mtf.reduce_sum(
        -router_probs * mtf.log(router_probs + 1e-9),
        reduced_dim=experts_dim)
    mtf.scalar_summary(name + "/entropy", mtf.reduce_mean(token_entropy))

    # Share of all top-k selections that went to each expert.
    per_expert_count = mtf.reduce_sum(topk_mask, output_shape=[experts_dim])
    routed_total = mtf.reduce_sum(per_expert_count)
    per_expert_share = mtf.to_float(per_expert_count / routed_total)
    share_slices = mtf.split(
        per_expert_share,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for share in share_slices:
      summary_tag = "experts/" + share.name.replace(":", "/")
      mtf.scalar_summary(summary_tag, mtf.reduce_mean(share))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # ---- Assignment of tokens to experts ----
  # No-token-left-behind: place as many tokens as possible on their top-i
  # expert before falling back to top-(i+1).
  per_rank_masks = mtf.split(
      topk_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  per_rank_gates = mtf.split(
      topk_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  per_rank_indices = mtf.split(
      topk_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # State accumulated across the routing rounds.
  # NOTE(review): the combine accumulator is created without group_size_dim;
  # the `+=` in the loop presumably broadcasts it -- confirm.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  # 1.0 marks a token that has not been placed on any expert yet.
  tokens_left_to_route = mtf.constant(
      inputs.mesh,
      value=1.0,
      shape=[outer_batch_dim, batch_dim, group_size_dim])
  capacity_limit = float(expert_capacity_dim.size)

  routing_rounds = zip(per_rank_masks, per_rank_gates, per_rank_indices)
  for rank_mask, rank_gate, rank_index in routing_rounds:
    rank_mask = mtf.reshape(
        rank_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Only tokens that are still unrouted take part in this round.
    rank_mask *= tokens_left_to_route

    # Running per-expert token count: tokens accepted in earlier rounds plus
    # a cumulative sum of this round's candidates along the group.
    running_count = cum_tokens + mtf.cumsum(rank_mask, group_size_dim)

    # 1.0 where the running count is still within capacity, else 0.0.
    within_capacity = mtf.to_float(
        mtf.less_equal(running_count, capacity_limit))
    accepted = rank_mask * within_capacity

    # Fold this round's accepted tokens into the cross-round state.
    cum_tokens += mtf.reduce_sum(accepted, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(accepted, reduced_dim=experts_dim))

    # Per-token indicator of acceptance in this round. It is the same
    # reduction as the one just above, deliberately built a second time.
    accepted_per_token = mtf.reduce_sum(accepted, reduced_dim=experts_dim)
    # Zero-based slot index derived from the running count.
    slot_in_expert = running_count - 1

    # gate * accepted * one_hot(expert) * one_hot(slot), multiplied left to
    # right.
    round_combine = rank_gate * accepted_per_token
    round_combine = round_combine * mtf.one_hot(rank_index, experts_dim)
    round_combine = round_combine * mtf.one_hot(
        mtf.to_int32(slot_in_expert), expert_capacity_dim)
    combine_tensor += round_combine

  # Hand results back in the dtype of the inputs.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # The dispatch tensor is the non-zero pattern of the combine tensor.
  combine_is_nonzero = mtf.cast(combine_tensor, tf.bool)
  dispatch_tensor = mtf.cast(combine_is_nonzero, combine_tensor.dtype)
  return dispatch_tensor, combine_tensor, loss