def _switch_max_gating()

in mesh_tensorflow/transformer/moe.py


def _switch_max_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="switch_max_gating",
    num_microbatches=None, token_embeddings=None):
  """Compute Switch gating."""
  # TODO(barretzoph,liamfedus): Refactor switch_max, switch and ntlb to limit
  # code duplication.
  # SELECT EXPERT
  if train:
    policy = hparams.moe_switch_policy_train
  else:
    policy = hparams.moe_switch_policy_eval

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations
  if policy == "input_dropout":
    gate_inputs = mtf.dropout(
        gate_inputs, is_training=train,
        keep_prob=1.0 - hparams.moe_switch_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_switch_jitter)

  if hparams.moe_word_embed_mode is not None:
    gate_inputs = _add_token_emb_to_gate_inputs(
        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)

  if hparams.moe_use_second_place_expert_prob is not None and train:
    gate_logits = _stochastically_use_non_top_expert(
        gate_logits, experts_dim, hparams)

  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
    expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
  elif policy == "sample":
    expert_index = mtf.sample_with_temperature(
        gate_logits, experts_dim, temperature=hparams.moe_switch_temperature)
    expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
  else:
    raise ValueError("Unknown Switch gating policy %s" % policy)

  expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)
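  # expert_mask: one-hot selection over experts_dim; together with expert_gate
  # it drives both the load-balance loss and the combine weights below.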

  # LOAD BALANCING LOSS
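  # density_1: fraction of tokens in each group routed to each expert.
  # density_1_proxy: mean router probability assigned to each expert.
  # Scaling the mean of their product by num_experts^2 makes the loss 1.0 for
  # perfectly uniform routing and larger as routing collapses onto fewer
  # experts.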
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for router.
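  # The z-loss penalizes large router logits, keeping the router softmax
  # numerically stable during training.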
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # Logging
  if train:
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)
    mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # Instead of the usual cumulative-sum truncation, take the top
  # `expert_capacity` tokens routed to each expert. If fewer than
  # `expert_capacity_dim` tokens are routed to an expert, the combine_tensor
  # zeroes out the unused slots.
  # expert_mask shape: [outer_batch, batch, group_size, experts_unsplit]
  # expert_gate shape: [outer_batch, batch, group_size]
  expert_masked_probs = expert_mask * expert_gate
  expert_gate_probs, expert_gate_indices = mtf.top_k(
      expert_masked_probs, reduced_dim=group_size_dim,
      k_dim=expert_capacity_dim)
  dispatch_tensor = mtf.one_hot(
      expert_gate_indices, group_size_dim, dtype=raw_gates.dtype)
  combine_tensor = dispatch_tensor * expert_gate_probs
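  # combine_tensor: for each expert capacity slot, the router probability of
  # the token filling that slot, mapped back to its position in the group
  # (zero for slots left unused).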

  if train:
    total_routed = mtf.reduce_sum(mtf.cast(mtf.greater(combine_tensor, 0.0),
                                           dtype=raw_gates.dtype))
    importance = mtf.cast(importance, dtype=total_routed.dtype)
    mtf.scalar_summary("fraction_routed",
                       total_routed / mtf.reduce_sum(importance))

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
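  # Re-derive dispatch_tensor from the nonzero pattern of combine_tensor so
  # that unused capacity slots are not dispatched.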
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
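
For orientation, here is a minimal NumPy sketch of the same routing math for a
single group. The names (`switch_max_route`, `num_experts`, `expert_capacity`)
are illustrative and not part of the mtf API; the sketch mirrors the top-1
selection, the load-balance loss, and the per-expert top-k token selection
above, but it is a reading aid under those assumptions, not a drop-in
replacement for the mesh-tensorflow implementation.

import numpy as np


def switch_max_route(gate_logits, expert_capacity):
  """gate_logits: [group_size, num_experts] -> (dispatch, combine, loss)."""
  group_size, num_experts = gate_logits.shape
  # Softmax over experts, then top-1 selection per token.
  raw_gates = np.exp(gate_logits - gate_logits.max(-1, keepdims=True))
  raw_gates /= raw_gates.sum(-1, keepdims=True)
  expert_index = raw_gates.argmax(-1)
  expert_gate = raw_gates[np.arange(group_size), expert_index]
  expert_mask = np.eye(num_experts)[expert_index]

  # Switch load-balance loss: equals 1.0 for perfectly uniform routing.
  density_1 = expert_mask.mean(axis=0)
  density_1_proxy = raw_gates.mean(axis=0)
  loss = (density_1 * density_1_proxy).mean() * num_experts ** 2

  # Per expert, keep the top-`expert_capacity` tokens by gate probability
  # (instead of the usual position-based cumulative-sum truncation).
  masked_probs = expert_mask * expert_gate[:, None]              # [group, experts]
  top_idx = np.argsort(-masked_probs, axis=0)[:expert_capacity]  # [capacity, experts]
  top_probs = np.take_along_axis(masked_probs, top_idx, axis=0)  # [capacity, experts]
  dispatch = np.zeros((num_experts, expert_capacity, group_size))
  combine = np.zeros_like(dispatch)
  for e in range(num_experts):
    for c in range(expert_capacity):
      # Unused slots (probability 0) stay zeroed out, as in the mtf code.
      dispatch[e, c, top_idx[c, e]] = float(top_probs[c, e] > 0)
      combine[e, c, top_idx[c, e]] = top_probs[c, e]
  return dispatch, combine, loss


# Example: 8 tokens per group, 4 experts, capacity 2 per expert.
dispatch, combine, loss = switch_max_route(
    np.random.randn(8, 4), expert_capacity=2)

Compared with the usual cumulative-sum truncation, taking the top
`expert_capacity` tokens per expert keeps the highest-probability assignments
rather than the earliest positions in the group; tokens that lose this
competition get a zero combine weight and are effectively dropped.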