def _expert_selection_gating()

in mesh_tensorflow/transformer/moe.py


# Module-level imports used below (as in moe.py).
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def _expert_selection_gating(
    inputs, outer_expert_dims, experts_dim, group_size_dim,
    expert_capacity_dim, hparams, train, variable_dtype, importance=None,
    name="expert_selection_gating", num_microbatches=None,
    normalize_by_num_experts_routed=True, token_embeddings=None):
  """Compute gating where each expert chooses what tokens it wants."""
  # Select the randomization policy.
  if train:
    policy = hparams.moe_switch_policy_train
  else:
    policy = hparams.moe_switch_policy_eval

  # The internals of this function run in float32; otherwise numerical
  # instabilities can occur.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations for exploration.
  if policy == "input_dropout":
    gate_inputs = mtf.dropout(gate_inputs, is_training=train,
                              keep_prob=1.0 - hparams.moe_switch_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_switch_jitter)

  if hparams.moe_word_embed_mode is not None:
    gate_inputs = _add_token_emb_to_gate_inputs(
        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

  # Compute expert logits for each token.
  # gate_logits shape: [outer_batch, batch, group, expert_unsplit]
  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)

  # Set token logits to -inf before the softmax where importance is zero,
  # since the softmax is normalized over all tokens in the group.
  if importance is not None:
    gate_logits += mtf.cast(
        mtf.equal(importance, 0.0), dtype=gate_logits.dtype) * -1e9
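
  # Note: the softmax below is normalized over the group (token) dimension,
  # not the expert dimension. In this expert-choice scheme each expert holds a
  # distribution over tokens -- the transpose of token-choice routing, where
  # each token holds a distribution over experts.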
  raw_gates = mtf.softmax(gate_logits, reduced_dim=group_size_dim)

  # expert_gate_probs shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity]
  # expert_gate_indices shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity]
  expert_gate_probs, expert_gate_indices = mtf.top_k(
      raw_gates, reduced_dim=group_size_dim, k_dim=expert_capacity_dim)
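
  # Each expert independently selects its top `expert_capacity` tokens, so a
  # token may be chosen by several experts or by none at all; an unchosen
  # token gets all-zero combine weights and hence a zero output from the
  # expert layer.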

  # dispatch_tensor shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity, group]
  dispatch_tensor = mtf.one_hot(
      expert_gate_indices, group_size_dim, dtype=raw_gates.dtype)

  # combine_tensor shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity, group]
  combine_tensor = dispatch_tensor * expert_gate_probs
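
  # Downstream (in the surrounding MoE layer, not shown here), the dispatch
  # tensor gathers token activations into per-expert buffers, and the combine
  # tensor scatters expert outputs back to token positions weighted by the
  # gate probabilities.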

  # A token's outputs are aggregated across every expert that selected it, so
  # its combined gate weights need not sum to 1. This can be an issue, so
  # optionally normalize by the number of experts each token is sent to.
  if normalize_by_num_experts_routed:
    num_experts_routed = mtf.reduce_sum(
        dispatch_tensor,
        output_shape=(dispatch_tensor.shape[:2] + [group_size_dim]))
    combine_tensor /= mtf.maximum(num_experts_routed, 1.0)
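    # E.g. a token selected by three experts has each of its combine weights
    # divided by 3; the maximum with 1.0 avoids division by zero for tokens
    # selected by no expert.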

  ################### Compute the load balancing loss ###################
  # Push `aggregated_group_probs` of size `group` (which sums to 1, being the
  # mean over experts of per-expert distributions that each sum to 1 over the
  # group) to be uniform.
  # aggregated_group_probs shape: [outer_batch, batch, group]
  # importance shape: [outer_batch, batch, group]
  aggregated_group_probs = mtf.reduce_mean(raw_gates, reduced_dim=experts_dim)
  if importance is not None:
    aggregated_group_probs *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)

  # Scale the loss by group_size to keep it constant across different group
  # sizes. true_group_size is the number of tokens per group that are not
  # masked out.
  if importance is not None:
    true_group_size = mtf.cast(
        mtf.reduce_sum(importance, reduced_dim=group_size_dim),
        dtype=raw_gates.dtype)
  else:
    true_group_size = float(group_size_dim.size)
  loss = (mtf.reduce_mean(
      aggregated_group_probs * aggregated_group_probs * true_group_size) *
          float(group_size_dim.size))
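
  # With this scaling, and with no masked tokens, a perfectly uniform router
  # (every token receiving probability 1/group_size from the mean over
  # experts) yields a loss of exactly 1.0 regardless of group_size; a sum of
  # squares with a fixed sum is minimized at the uniform distribution, so
  # minimizing this loss balances the router.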

  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for router.
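  # _router_z_loss (defined elsewhere in moe.py) penalizes large router
  # logits, which helps keep the float32 gating computation numerically
  # stable.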
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  ################### Logging ###################
  if train:
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=group_size_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Log, for each token position in the group, how many experts it is sent
    # to on average over the batch. reduce_mean divides by the sizes of all
    # reduced dimensions, so multiplying back by experts * capacity leaves
    # the per-batch average count.
    num_experts_sent_per_token = (
        mtf.reduce_mean(dispatch_tensor, output_shape=[group_size_dim]) *
        float(experts_dim.size * expert_capacity_dim.size))
    split_fractions = mtf.split(
        num_experts_sent_per_token,
        split_dim=group_size_dim,
        num_or_size_splits=group_size_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("group_token/" + fraction.name.replace(":", "/"),
                         mtf.reduce_sum(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  #################### Match the inputs dtype ###################
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
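  # Round-tripping through bool keeps the dispatch tensor exactly binary in
  # the activation dtype (one_hot already emits 0/1 values, so only the dtype
  # changes).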
  dispatch_tensor = mtf.cast(
      mtf.cast(dispatch_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
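
A minimal usage sketch (not from the source; every name and size below is
illustrative). It builds a toy mesh and a stand-in hparams object exposing
only the fields read above, and omits lowering the mtf graph to a tf.Session:


import types

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "moe_example")

outer_batch = mtf.Dimension("outer_batch", 1)
batch = mtf.Dimension("batch", 2)
group = mtf.Dimension("group", 8)
d_model = mtf.Dimension("d_model", 16)
experts = mtf.Dimension("expert_unsplit", 4)
capacity = mtf.Dimension("expert_capacity", 2)

# Stand-in hparams exposing only the fields _expert_selection_gating reads.
hparams = types.SimpleNamespace(
    moe_switch_policy_train="input_jitter",
    moe_switch_policy_eval="",
    moe_switch_dropout=0.1,
    moe_switch_jitter=1e-2,
    moe_word_embed_mode=None,
    moe_z_loss=None)

inputs = mtf.random_uniform(
    mesh, mtf.Shape([outer_batch, batch, group, d_model]))

# dispatch/combine shapes:
# [outer_batch, batch, expert_unsplit, expert_capacity, group];
# aux_loss is the scalar load-balance loss.
dispatch, combine, aux_loss = _expert_selection_gating(
    inputs,
    outer_expert_dims=None,
    experts_dim=experts,
    group_size_dim=group,
    expert_capacity_dim=capacity,
    hparams=hparams,
    train=True,
    variable_dtype=mtf.VariableDType(tf.float32, tf.float32, tf.float32))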