mesh_tensorflow/transformer/moe.py [938:963]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for router.
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # Logging
  if train:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
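
For reference, the duplicated block above computes the standard MoE load-balancing auxiliary loss: the per-expert fraction of routed tokens (density_1) multiplied by the per-expert mean gate probability (density_1_proxy), averaged over experts and scaled by num_experts**2 so that a perfectly uniform router scores 1.0. Below is a minimal NumPy restatement of the unmasked path (importance is None); the function name and shapes are illustrative, not part of moe.py:

    import numpy as np

    def load_balance_loss(raw_gates, expert_mask):
      """NumPy restatement of the auxiliary loss above (illustrative only).

      raw_gates:   [group_size, num_experts] softmax router probabilities.
      expert_mask: [group_size, num_experts] one-hot top-1 assignments.
      """
      num_experts = raw_gates.shape[-1]
      # Fraction of tokens routed to each expert, and the mean router
      # probability per expert (the differentiable proxy).
      density_1 = expert_mask.mean(axis=0)
      density_1_proxy = raw_gates.mean(axis=0)
      # Scaled so a perfectly balanced router gives a loss of exactly 1.0.
      return float(num_experts ** 2) * np.mean(density_1 * density_1_proxy)

    # Example: 4 tokens split evenly across 2 experts -> loss == 1.0.
    gates = np.full((4, 2), 0.5)
    mask = np.array([[1., 0.], [0., 1.], [1., 0.], [0., 1.]])
    print(load_balance_loss(gates, mask))  # 1.0
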
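The z_loss branch calls _router_z_loss, which penalizes large router logits to keep the softmax numerically stable (the ST-MoE recipe). A hedged NumPy sketch, assuming the loss is the mean squared log-sum-exp of the gate logits over the experts dimension; this is an approximation of what _router_z_loss computes, not a copy of it:

    import numpy as np

    def router_z_loss(gate_logits):
      """Sketch of a router z-loss: mean squared logsumexp of the logits.

      gate_logits: [group_size, num_experts] pre-softmax router logits.
      Assumed behavior of _router_z_loss, without the importance masking
      and microbatch normalization the real helper applies.
      """
      # Numerically stable logsumexp over the experts axis.
      m = gate_logits.max(axis=-1, keepdims=True)
      log_z = np.squeeze(m, -1) + np.log(
          np.exp(gate_logits - m).sum(axis=-1))
      return np.mean(np.square(log_z))

Because the penalty grows quadratically in log_z, gradients push the router toward small-magnitude logits without changing the relative expert ranking.
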



mesh_tensorflow/transformer/moe.py [1098:1123]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Add in the z_loss for router.
  if train and hparams.moe_z_loss is not None:
    tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
    z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                            importance)
    mtf.scalar_summary(name + "/z_loss", z_loss)
    loss += (hparams.moe_z_loss * z_loss)

  # Logging
  if train:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
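
The two spans, [938:963] and [1098:1123], are byte-identical, so the duplication could be removed by hoisting the block into a module-level helper that both gating functions call. A sketch under that assumption follows; the helper name and parameter list are hypothetical, and the body is the duplicated code verbatim except that the repeated nonpadding cast is computed once:

    import mesh_tensorflow as mtf
    import tensorflow.compat.v1 as tf  # matching the module's TF1-style logging

    def _load_balance_and_z_loss(inputs, raw_gates, gate_logits, expert_mask,
                                 expert_gate, importance, experts_dim, hparams,
                                 num_microbatches, train, name):
      """Hypothetical shared helper for the duplicated loss block above."""
      group_size_dim = inputs.shape[-2]
      density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
      density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
      if importance is not None:
        # Zero out padding positions (importance != 1.0) in the statistics.
        nonpadding = mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
        expert_mask *= nonpadding
        expert_gate *= nonpadding
        density_1_proxy *= nonpadding
      loss = (
          mtf.reduce_mean(density_1_proxy * density_1) *
          float(experts_dim.size * experts_dim.size))
      if num_microbatches and num_microbatches > 1:
        tf.logging.info(
            "Dividing load-balance loss by num_microbatches={}".format(
                num_microbatches))
        loss /= num_microbatches
      if train and hparams.moe_z_loss is not None:
        tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
        z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
                                importance)
        mtf.scalar_summary(name + "/z_loss", z_loss)
        loss += hparams.moe_z_loss * z_loss
      return loss, expert_mask, expert_gate

Each call site would then collapse to a single assignment, e.g. loss, expert_mask, expert_gate = _load_balance_and_z_loss(...), leaving only the per-function logging that follows the span.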