in mesh_tensorflow/transformer/moe.py [0:0]
def _switch_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="switch_gating",
    num_microbatches=None, token_embeddings=None):
  """Compute Switch (top-1) gating.

  Every token is assigned to exactly one expert, chosen from a softmax over
  router logits according to the configured policy. Tokens whose position in
  their chosen expert's queue is >= the expert capacity are dropped: their
  dispatch/combine entries are all zero.

  Args:
    inputs: an mtf.Tensor. The router is a dense layer applied to it, and
      `inputs.shape[-2]` is treated as the group dimension over which the
      load-balancing statistics and per-expert positions are computed
      (assumes the last dim is the model dim -- confirm against callers).
    outer_expert_dims: forwarded as `expert_dims` to the router dense layer.
    experts_dim: an mtf.Dimension whose size is the number of experts.
    expert_capacity_dim: an mtf.Dimension whose size is the maximum number of
      tokens each expert accepts per group.
    hparams: hyperparameters. Reads moe_switch_policy_train,
      moe_switch_policy_eval, moe_switch_dropout, moe_switch_jitter,
      moe_word_embed_mode, moe_use_second_place_expert_prob,
      moe_switch_temperature and moe_z_loss.
    train: a boolean, whether we are in training mode.
    variable_dtype: forwarded to the router dense layer.
    importance: optional mtf.Tensor. When given, only tokens whose importance
      equals 1.0 are routed; all others get zero dispatch/combine weight.
    name: name of the router layer; also used as a summary prefix.
    num_microbatches: optional int. When > 1 the load-balancing loss is
      divided by it.
    token_embeddings: optional; only used when hparams.moe_word_embed_mode is
      not None.

  Returns:
    dispatch_tensor: 0/1 tensor with the same shape as combine_tensor,
      nonzero where combine_tensor is nonzero.
    combine_tensor: the gate probability broadcast over experts_dim and
      expert_capacity_dim, nonzero only at each kept token's
      (expert, position) slot. Cast to inputs.dtype.
    loss: the auxiliary load-balancing loss, plus
      hparams.moe_z_loss * router z-loss when that is enabled in training.
      Cast to inputs.dtype.

  Raises:
    ValueError: if the selected switch policy is unknown.
  """
  # SELECT EXPERT
  policy = (hparams.moe_switch_policy_train if train
            else hparams.moe_switch_policy_eval)
  # The internals of this function run in float32.
  # bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations.
  if policy == "input_dropout":
    gate_inputs = mtf.dropout(
        gate_inputs,
        is_training=train,
        keep_prob=1.0 - hparams.moe_switch_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_switch_jitter)
  if hparams.moe_word_embed_mode is not None:
    gate_inputs = _add_token_emb_to_gate_inputs(
        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  if hparams.moe_use_second_place_expert_prob is not None and train:
    gate_logits = _stochastically_use_non_top_expert(
        gate_logits, experts_dim, hparams)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  if policy in ("argmax", "input_dropout", "input_jitter"):
    expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
  elif policy == "sample":
    expert_index = mtf.sample_with_temperature(
        gate_logits, experts_dim, temperature=hparams.moe_switch_temperature)
    expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
  else:
    raise ValueError("Unknown Switch gating policy %s" % policy)
  expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)

  # LOAD BALANCING LOSS
  group_size_dim = inputs.shape[-2]
  # density_1: fraction of the group's tokens whose top choice is each expert.
  # density_1_proxy: mean router probability per expert (differentiable).
  # Both are computed before the importance mask is applied below.
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    # Only tokens with importance exactly 1.0 participate in routing.
    importance_mask = mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_mask *= importance_mask
    expert_gate *= importance_mask
    density_1_proxy *= importance_mask
  # The experts_dim**2 factor makes a perfectly uniform routing give loss 1.0.
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)
    mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # Add in the z_loss for router.
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # COMPUTE ASSIGNMENT TO EXPERT
  # Experts have a limited capacity, ensure we do not exceed it. Construct
  # the batch indices, to each expert, with position_in_expert.
  # The exclusive cumsum gives each token the number of earlier tokens in the
  # group already routed to the same expert.
  position_in_expert = mtf.cumsum(
      expert_mask, group_size_dim, exclusive=True) * expert_mask
  position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)
  # Keep only tokens that fit within expert_capacity.
  expert_capacity_float = float(expert_capacity_dim.size)
  expert_mask *= mtf.cast(
      mtf.less(position_in_expert, expert_capacity_float),
      dtype=raw_gates.dtype)
  # 1.0 for tokens that are routed and kept, 0.0 otherwise.
  expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

  if train:
    total_routed = mtf.reduce_sum(expert_mask_flat)
    if importance is not None:
      num_eligible = mtf.reduce_sum(
          mtf.cast(importance, dtype=total_routed.dtype))
    else:
      # Without an importance tensor every token is eligible for routing.
      num_eligible = float(expert_mask_flat.shape.size)
    mtf.scalar_summary("fraction_routed", total_routed / num_eligible)

  # Mask out the experts that have overflowed expert capacity. Sparsify the
  # expert_gate. expert_mask_flat is 0/1, so applying it once is enough.
  expert_gate *= expert_mask_flat
  combine_tensor = (
      expert_gate *
      mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
      mtf.one_hot(
          mtf.to_int32(position_in_expert),
          expert_capacity_dim,
          dtype=raw_gates.dtype))

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)
  return dispatch_tensor, combine_tensor, loss