def transformer_moe_layer_v1()

in mesh_tensorflow/transformer/heterogeneous_moe.py


def transformer_moe_layer_v1(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
    num_microbatches=None, token_embeddings=None, context=None):
  """Local heterogenous mixture of experts.

  See transformer_moe_layer_v1 in moe.py for a more detailed explanation of a
  generic MoE layer.

  The heterogeneous mask output by generate_heterogeneous_expert_masks has
  shape [maximum hidden size, maximum # layers, # experts], and its dimensions
  overwrite the parameters moe_hidden_size and moe_num_layers in hparams. At
  each expert layer, the layer-specific mask slice is applied to the
  activation, which has shape [expert width, # experts]. If
  heterogeneous_mask_info is None, no mask is applied and the code is
  equivalent to the homogeneous case.
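
  For example (illustrative sizes, not defaults): with 4 experts, a mask of
  shape [3072, 2, 4] overwrites moe_hidden_size with 3072 and moe_num_layers
  with 2; the slice for layer 0 has shape [3072, 4] and, multiplied into the
  activation, zeroes out hidden units beyond each expert's width at that layer.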

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Dimensions cheat sheet:
  B: batch dim(s)
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity
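
  For example (illustrative sizes): 8 sequences of length 512 with input depth
  1024 give B*L = 4096 positions per replica; these are regrouped into G
  groups of S positions each (G*S = 4096) before dispatch, and each expert
  receives at most C positions from each group.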

  Args:
    inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
      and the same dtype as inputs, consisting of ones (at nonpadding
      positions) and zeros (at padding positions).
    activation: a function.
    num_microbatches: number of microbatches.
    token_embeddings: a mtf.Tensor with shape
      [batch_dim(s), length_dim, input_dim]. These are the word embeddings
      that correspond to the inputs. They can optionally be used to make
      routing decisions.
    context: a Context.

  Returns:
    outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
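
  Example call (a sketch; the surrounding variable names and hparams fields
  used here are assumptions, not guaranteed by this module):

    outputs, aux_loss = transformer_moe_layer_v1(
        inputs, output_dim=input_dim, hparams=hparams, train=True,
        variable_dtype=mtf.VariableDType(tf.float32),
        layout=hparams.layout, mesh_shape=hparams.mesh_shape,
        nonpadding=nonpadding, context=context)
    total_loss = task_loss + aux_loss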
  """
  orig_inputs = inputs

  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

  if hparams.moe_heterogeneous_mask_info is not None:
    tf.logging.info("moe_heterogeneous_mask_info: {}".format(
        hparams.moe_heterogeneous_mask_info))
    heterogeneous_mask = generate_heterogeneous_expert_masks(
        hparams.moe_heterogeneous_mask_info,
        hparams.moe_num_experts,
        experts_dim,
        mesh=inputs.mesh,
        expert_width=hparams.moe_hidden_size)
    # overwrite depth and width with the mask maximum dimension
    hparams.moe_num_layers = heterogeneous_mask.shape[1].size
    hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups is a multiple of the mesh dimension
  # over which those groups are split.
  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                      orig_inputs.shape.dims[-1])
  # Hack: we assume that
  #   "outer_batch" == replication of experts
  #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
  #
  # We then require num_groups to be a multiple of mesh_dim_size.
  if orig_inputs.shape.dims[0].name == "outer_batch":
    outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
  else:
    outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                       orig_inputs.shape.dims[0])

  # Number of MoE inputs (total number of positions across
  # batch_and_length_dims) per replica.
  n = 1
  for d in batch_and_length_dims:
    n *= d.size

  n = n // outer_batch_dim.size

  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                  orig_batch_dim)
  num_groups, group_size = moe._split_into_groups(  # pylint: disable=protected-access
      n, hparams.moe_group_size, mesh_dim_size)
  # TODO(barretzoph): implementation without pylint calls?
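  # Illustrative split (assumed sizes): with n = 4096 positions, moe_group_size
  # = 1024 and mesh_dim_size = 8, group_size is the largest divisor of n that
  # is <= 1024 and makes num_groups a multiple of 8, giving group_size = 512
  # and num_groups = 8.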

  group_size_dim = mtf.Dimension("group", group_size)
  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

  moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
  # OGSM Tensor
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Token embeddings that can be optionally used in the router for determining
  # where to send tokens.
  if hparams.moe_word_embed_mode is not None:
    token_embeddings = mtf.cast(
        mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

  # Each sequence sends expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  tf.logging.info("expert_capacity: %d" % expert_capacity)
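  # Illustrative arithmetic (assumed sizes, not defaults): with group_size =
  # 512, capacity_factor = 1.25 and 32 experts, expert_capacity =
  # min(512, int(512 * 1.25 / 32)) = 20 positions per (group, expert), raised
  # to moe_min_expert_capacity if that value is larger.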
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
  if nonpadding is not None:
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
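  # The regrouped nonpadding tensor ([O, G, S]) is passed to the gating
  # functions below as the `importance` argument, so padding positions can be
  # discounted when routing and when computing the auxiliary loss.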
  if hparams.moe_gating == "top_2":
    # dispatch_tensor and combine_tensor are OGSEC Tensors
    # (G is generally split along the mesh dim)
    dispatch_tensor, combine_tensor, loss = moe._top_2_gating(  # pylint: disable=protected-access
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "top_n":
    dispatch_tensor, combine_tensor, loss = moe._top_n_gating(  # pylint: disable=protected-access
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "switch":
    dispatch_tensor, combine_tensor, loss = moe._switch_gating(  # pylint: disable=protected-access
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "ntlb":
    dispatch_tensor, combine_tensor, loss = moe._ntlb_gating(  # pylint: disable=protected-access
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "switch_max":
    dispatch_tensor, combine_tensor, loss = moe._switch_max_gating(  # pylint: disable=protected-access
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "expert_selection":
    dispatch_tensor, combine_tensor, loss = moe._expert_selection_gating(  # pylint: disable=protected-access
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        group_size_dim=group_size_dim,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        name="expert_selection_gating",
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                             mtf.Shape([
                                 outer_batch_dim, experts_dim_unsplit,
                                 num_groups_dim, expert_capacity_dim, input_dim
                             ]))
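  # The einsum contracts the group_size ("S") dim: inputs is [O, G, S, M] and
  # dispatch_tensor is [O, G, S, E, C], so expert_inputs comes out as
  # [O, E, G, C, M] (dimension letters as in the docstring cheat sheet).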

  # Extra reshape reduces communication cost for model-parallel versions.
  # For model-parallel versions, this reshape causes an mtf.slice and for non-
  # model-parallel versions, this has no effect.
  d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          d_model_split_dim
      ]))

  # Split over batch -> split over experts
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          input_dim
      ]))
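  # Renaming "expert_unsplit" -> "experts" (and the group dim ->
  # "batch_unsplit") changes which mesh dimension the tensor is split over:
  # the data is now laid out split across experts rather than across the batch.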

  # heterogeneous_mask (when present) has shape
  # [expert_hidden, num_expert_layers, experts]; each expert layer uses its
  # own slice along the layer dimension.
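  # Each expert layer below is a pre-norm residual block: RMS-normalize the
  # input (for layers after the first), project to expert_hidden with "wi" and
  # apply the activation, then dropout and (optionally) the heterogeneous
  # mask, project back with "wo", and add the residual.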
  for layer_idx in range(hparams.moe_num_layers):
    with tf.variable_scope("expert_layer_{}".format(layer_idx)):
      res_h = 0.0
      if layer_idx > 0:
        res_h = expert_inputs
        expert_inputs = transformer.sublayer_rms_norm(
            expert_inputs, None, context)

      # Now feed the expert inputs through the experts.
      h = mtf.layers.dense_product(
          expert_inputs,
          reduced_dims=expert_inputs.shape.dims[-1:],
          new_dims=[hidden_dim],
          expert_dims=[experts_dim],
          activation_functions=activation, use_bias=False,
          variable_dtype=variable_dtype, name="wi")

      # apply dropout
      if hparams.moe_dropout_rate != 0.0:
        h = mtf.dropout(h, is_training=train,
                        keep_prob=1.0 - hparams.moe_dropout_rate)
      # only if heterogeneous
      if hparams.moe_heterogeneous_mask_info is not None:
        # Get mask for current layer by slicing heterogeneous mask
        heterogeneous_mask_slice = mtf.slice(
            heterogeneous_mask, layer_idx, 1, "num_expert_layers")

        # Get rid of the expert layers dimension.
        heterogeneous_mask_slice = mtf.reshape(
            heterogeneous_mask_slice,
            [heterogeneous_mask_slice.shape[0],
             heterogeneous_mask_slice.shape[-1]])
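        # The [expert_hidden, experts] slice broadcasts (by dimension name)
        # against h, zeroing the hidden units this layer's experts do not use.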
        h *= mtf.cast(heterogeneous_mask_slice, h.dtype)
      expert_output = mtf.layers.dense(
          h, output_dim, expert_dims=[experts_dim], use_bias=False,
          reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
          name="wo")

      if layer_idx < (hparams.moe_num_layers - 1):
        expert_output = transformer.sublayer_dropout(
            expert_output, None, context)
      expert_output += res_h
      expert_inputs = expert_output

  # Extra reshape reduces communication cost for model-parallel versions.
  # For model-parallel versions, this reshape causes an mtf.slice and for non-
  # model-parallel versions, this has no effect.
  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim, experts_dim_unsplit, num_groups_dim,
          expert_capacity_dim, d_model_split_dim
      ]))

  # Split over experts -> split over batch
  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim,
          experts_dim_unsplit,
          num_groups_dim,
          expert_capacity_dim,
          output_dim,
      ]))
  moe_output_dims = moe_input_dims[:-1] + [output_dim]
  output = mtf.einsum([expert_output, combine_tensor],
                      mtf.Shape(moe_output_dims))
  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
  return output, loss * hparams.moe_loss_coef