mesh_tensorflow/transformer/heterogeneous_moe.py [437:457]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim, experts_dim_unsplit, num_groups_dim,
          expert_capacity_dim, d_model_split_dim
      ]))

  # Split over experts -> split over batch
  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim,
          experts_dim_unsplit,
          num_groups_dim,
          expert_capacity_dim,
          output_dim,
      ]))
  moe_output_dims = moe_input_dims[:-1] + [output_dim]
  output = mtf.einsum([expert_output, combine_tensor],
                      mtf.Shape(moe_output_dims))
  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



mesh_tensorflow/transformer/moe.py [531:551]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



