mesh_tensorflow/transformer/heterogeneous_moe.py [254:284]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  group_size_dim = mtf.Dimension("group", group_size)
  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

  moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
  # OGSM Tensor: [outer_batch, num_groups, group_size, model].
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Token embeddings that the router can optionally use when deciding where to
  # send tokens.
  if hparams.moe_word_embed_mode is not None:
    token_embeddings = mtf.cast(
        mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

  # Each group sends at most expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  tf.logging.info("expert_capacity: %d" % expert_capacity)
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  # Renamed (unsplit) copies of the expert and group dimensions, used later
  # when dispatching inputs to experts and combining their outputs.
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
  if nonpadding is not None:
    # Broadcast the padding mask to the full [batch, length] shape, then
    # regroup it to match the OGSM layout (minus the model dimension).
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
  if hparams.moe_gating == "top_2":
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
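
Both excerpts size each expert's buffer the same way: the capacity factor scales how many of a group's positions a single expert may accept, clamped between the configured minimum and the group size. A minimal sketch of that arithmetic in plain Python, using hypothetical numbers (group size 1024, 64 experts, capacity factor 1.25, minimum capacity 4) that are not taken from the code above:

group_size = 1024          # group_size_dim.size (assumed)
num_experts = 64           # experts_dim.size (assumed)
capacity_factor = 1.25     # hparams.moe_capacity_factor_train (assumed)
min_expert_capacity = 4    # hparams.moe_min_expert_capacity (assumed)

# Same formula as above: scale the per-expert share of the group by the
# capacity factor, but never exceed the group size or fall below the minimum.
expert_capacity = min(group_size,
                      int(group_size * capacity_factor / num_experts))
expert_capacity = max(expert_capacity, min_expert_capacity)
print(expert_capacity)  # 1024 * 1.25 / 64 = 20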



mesh_tensorflow/transformer/moe.py [371:401]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  group_size_dim = mtf.Dimension("group", group_size)
  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

  moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
  # OGSM Tensor: [outer_batch, num_groups, group_size, model].
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Token embeddings that the router can optionally use when deciding where to
  # send tokens.
  if hparams.moe_word_embed_mode is not None:
    token_embeddings = mtf.cast(
        mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

  # Each group sends at most expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  tf.logging.info("expert_capacity: %d" % expert_capacity)
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  # Renamed (unsplit) copies of the expert and group dimensions, used later
  # when dispatching inputs to experts and combining their outputs.
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
  if nonpadding is not None:
    # Broadcast the padding mask to the full [batch, length] shape, then
    # regroup it to match the OGSM layout (minus the model dimension).
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
  if hparams.moe_gating == "top_2":
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
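
The reshape near the top of both excerpts carves the flat batch-by-length token stream into fixed-size groups, so that routing decisions and expert_capacity are computed per group. A NumPy sketch with assumed shapes (outer batch 1, batch 8, length 256, d_model 512, group size 1024), none of which come from the code above:

import numpy as np

outer_batch, batch, length, d_model = 1, 8, 256, 512  # assumed shapes
group_size = 1024                                      # assumed group size
num_groups = (batch * length) // group_size            # -> 2

x = np.zeros((outer_batch, batch, length, d_model))    # stand-in activations
# O x G x S x M layout, analogous to moe_input_dims in the snippets above.
x = x.reshape(outer_batch, num_groups, group_size, d_model)
print(x.shape)  # (1, 2, 1024, 512)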



