mesh_tensorflow/transformer/heterogeneous_moe.py [349:387]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        group_size_dim=group_size_dim,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        name="expert_selection_gating",
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                             mtf.Shape([
                                 outer_batch_dim, experts_dim_unsplit,
                                 num_groups_dim, expert_capacity_dim, input_dim
                             ]))

  # Extra reshape reduces communication cost for model-parallel versions.
  # In model-parallel versions, this reshape causes an mtf.slice; in
  # non-model-parallel versions, it has no effect.
  d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          d_model_split_dim
      ]))

  # Split over batch -> split over experts
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          input_dim
      ]))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



mesh_tensorflow/transformer/moe.py [466:504]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        group_size_dim=group_size_dim,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        name="expert_selection_gating",
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                             mtf.Shape([
                                 outer_batch_dim, experts_dim_unsplit,
                                 num_groups_dim, expert_capacity_dim, input_dim
                             ]))

  # Extra reshape reduces communication cost for model-parallel versions.
  # In model-parallel versions, this reshape causes an mtf.slice; in
  # non-model-parallel versions, it has no effect.
  d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          d_model_split_dim
      ]))

  # Split over batch -> split over experts
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          input_dim
      ]))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



