def _top_n_gating()

in mesh_tensorflow/transformer/moe.py [0:0]


def _top_n_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="top_n_gating",
    num_microbatches=None, token_embeddings=None):
  """Compute generalization of top-2 gating for mixture-of-experts.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold: a float
    hparams.moe_top_n_num_experts_per_token: an int

  Tensor shapes are largely the same as in top_2 gating, so see that docstring
  for more details.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string
    num_microbatches: number of microbatches.
    token_embeddings: an optional tensor with shape
      [<batch_dims>, group_size_dim, input_dim] that is the input
      word embeddings.

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """

  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  # The internals of this function run in float32.
  # bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  if hparams.moe_word_embed_mode is not None:
    gate_inputs = _add_token_emb_to_gate_inputs(
        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

  gate_logits = mtf.layers.dense(
      gate_inputs, experts_dim, use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, experts_dim)

  expert_capacity_f = float(expert_capacity_dim.size)

  # Used for aux loss.
  density_1_proxy = raw_gates
  if importance is not None:
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))

  # Loop over the get the top-n tokens and their masks.
  gates = []
  masks = []
  indexes = []
  # Tensor that contains all but the top-n highest experts for each token.
  gates_without_top_n = raw_gates
  gates_without_top_1 = None  # Used for second place loss
  for n in range(hparams.moe_top_n_num_experts_per_token):
    # [batch, group]
    gate_n, index_n = mtf.top_1(gates_without_top_n, experts_dim)
    # [batch, group, experts]
    mask_n = mtf.one_hot(index_n, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
      mask_n *= mtf.to_float(mtf.greater(importance, 0.0))
      gate_n *= mtf.to_float(mtf.greater(importance, 0.0))
    gates_without_top_n *= (1.0 - mask_n)
    # Used for second place loss.
    if n == 1:
      gates_without_top_1 = gates_without_top_n
    gates.append(gate_n)
    masks.append(mask_n)
    indexes.append(index_n)

  if len(gates) > 1:
    # All gates probs are normalized over the top-n tokens.
    denom = mtf.add_n(gates) + 1e-9
    gates = [gate / denom for gate in gates]

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert.
  mask_1 = masks[0]  # Mask for top-1 token.
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))
  # TODO(barretzoph): Add in options for aux losses for n > 2.
  if hparams.moe_use_second_place_loss:
    pass
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(masks[2], reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for router.
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  def _update_mask_based_on_gate_value(gate_n, mask_n):
    """Update the mask based in the policy and the threshold for n>1.

    Args:
      gate_n: normalized router probability for the nth highest expert.
      mask_n: boolean one-hot tensor that keeps track of the nth expert to
        send to each toke. This also masks away tokens that will not be routed.

    Returns:
      An altered mask_n that will mask out any top-n token that doesn't follow
      the second_policy method and threshold.
    """
    if train:
      policy = hparams.moe_second_policy_train
      threshold = hparams.moe_second_threshold_train
    else:
      policy = hparams.moe_second_policy_eval
      threshold = hparams.moe_second_threshold_eval
    if policy == "all":
      # Use nth-place experts for all examples.
      pass
    elif policy == "none":
      # Never use nth-place experts for all examples.
      mask_n = mtf.zeros_like(mask_n)
    elif policy == "threshold":
      # Use nth-place experts if gate_n > threshold.
      mask_n *= mtf.to_float(mtf.greater(gate_n, threshold))
    elif policy == "random":
      # Use nth-place experts with probablity min(1.0, gate_n / threshold).
      mask_n *= mtf.to_float(
          mtf.less(mtf.random_uniform(gate_n.mesh, gate_n.shape),
                   gate_n / max(threshold, 1e-9)))
    else:
      raise ValueError("Unknown policy %s" % policy)
    return mask_n

  # Now update masks for n>1 to reflect how these additional tokens should be
  # routed according to their corresponding policies.
  # Only update for n>1 as we always want to route the top-1 token.
  for i in range(1, len(masks)):
    masks[i] = _update_mask_based_on_gate_value(gates[i], masks[i])

  def _compute_top_n_mask(gate_n, mask_n, index_n, prev_mask_count):
    # This is the position within the expert's mini-batch for this sequence.
    position_in_expert_n = (
        mtf.cumsum(mask_n, group_size_dim, exclusive=True) + prev_mask_count)
    # Mask out tokens that should not be routed.
    position_in_expert_n *= mask_n
    # Remove the elements that don't fit. [batch, group, experts]
    mask_n *= mtf.to_float(mtf.less(position_in_expert_n, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert.
    mask_n_count = mtf.reduce_sum(mask_n, reduced_dim=group_size_dim)
    # Keep running sum of total tokens sent to each expert.
    prev_mask_count += mask_n_count

    # [batch, group] - mostly ones, but zeros where something didn't fit.
    mask_n_flat = mtf.reduce_sum(mask_n, reduced_dim=experts_dim)
    # Weight assigned to nth expert.  [batch, group]
    gate_n *= mask_n_flat
    # [batch, group]
    position_in_expert_n = mtf.reduce_sum(
        position_in_expert_n, reduced_dim=experts_dim)
    partial_combine_tensor = (
        gate_n * mask_n_flat
        * mtf.one_hot(index_n, experts_dim)
        * mtf.one_hot(mtf.to_int32(position_in_expert_n), expert_capacity_dim))
    return prev_mask_count, partial_combine_tensor

  # [batch, experts]
  # How many examples in this group go to each expert. This starts at zero.
  prev_mask_count = 0.0
  partial_combine_tensors = []
  for gate_n, mask_n, index_n in zip(gates, masks, indexes):
    prev_mask_count, partial_combine_tensor = _compute_top_n_mask(
        gate_n, mask_n, index_n, prev_mask_count)
    partial_combine_tensors.append(partial_combine_tensor)
  combine_tensor = mtf.add_n(partial_combine_tensors)

  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss