def transformer_moe_layer_v1()

in mesh_tensorflow/transformer/moe.py


def transformer_moe_layer_v1(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
    num_microbatches=None, token_embeddings=None):
  """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_min_expert_capacity: an integer lower bound on expert capacity
    hparams.moe_gating: a string
    hparams.moe_word_embed_mode: if not None, token_embeddings are made
      available to the router
    hparams.moe_dropout_rate: dropout rate applied to the expert hidden layer
    hparams.moe_use_experts_attention: if True, the layer returns a list of
      [q, k] outputs instead of a single tensor
    hparams.moe_loss_coef: multiplier on the auxiliary load-balancing loss
    + all hyperparameters used by the chosen gating function
      (e.g. _top_2_gating())

  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts)

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)
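
  For example (illustrative numbers only, not defaults): with input_dim.size =
  output_dim.size = 1024, moe_num_experts = 512 and moe_hidden_size = 8192,
  the gating network has 1024 * 512 = 524,288 parameters and the experts have
  512 * (1024 + 1024) * 8192 (roughly 8.6 billion) parameters.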

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.
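
    For example (using the illustrative sizes from the diagram in the code
    below): a batch of 4096 sequences of length 256 gives 4096 * 256 =
    1,048,576 positions per replica; with moe_group_size = 1024 these are
    divided into 1024 groups of 1024 positions each, and every group sends at
    most expert_capacity positions to each expert.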

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  O: outer batch dim (used for expert replication)
  B: batch dim(s)
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity

  Args:
    inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).
    activation: a function.
    num_microbatches: number of microbatches.
    token_embeddings: a mtf.Tensor with shape
      [batch_dim(s), length_dim, input_dim]. These are the word embeddings
      that correspond to the inputs. They can optionally be used to make
      routing decisions.

  Returns:
    outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
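
  Example usage (a sketch only; `x`, `model_dim`, `my_hparams` and `task_loss`
  are placeholders assumed to be defined by the caller):

    moe_output, aux_loss = transformer_moe_layer_v1(
        x, model_dim, my_hparams, train=True,
        variable_dtype=mtf.VariableDType())
    # aux_loss is already scaled by hparams.moe_loss_coef.
    total_loss = task_loss + aux_loss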
  """
  # pylint: disable=line-too-long
  #
  # The outer_batch dimension (O) can be used for expert replication, e.g.
  # outer_batch=4 for placing 128 experts on 512 cores with 4 replicas of each
  # expert.
  #
  # E.g. 16x16 basic example:
  #   moe_num_experts=512, num_groups=1024, batch=4096, length=256, d_model=1024
  # ---
  # Below, ` indicates the common way of splitting along a mesh dimension.
  #
  # orig_inputs      OB`LM Tensor
  #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
  #                  v (reshaped)
  # inputs           OG`SM
  #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
  #
  # combine_tensor,
  # dispatch_tensor  OG`SEC
  #                  Shape[outer_batch=1, batch=1024, group=1024, expert_unsplit=512, expert_capacity=4]
  #
  # (dispatched inputs)
  # expert_inputs    OEG`CM
  #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
  #                  v (re-split via ReshapeOperation)
  #                  OE`GCM
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
  #
  # (hidden representation)
  # h                OE`GCH
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, expert_hidden=8192]
  #
  # expert_output    OE`GCM
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
  #                  v (re-split via ReshapeOperation)
  #                  OEG`CM
  #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
  #
  # (combined expert_output)
  # output           OG`SM
  #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
  #                  v (reshape)
  #                  OB`LM
  #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
  #
  # pylint: enable=line-too-long
  orig_inputs = inputs
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups is a multiple of the mesh dimension
  # over which those groups are split.
  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                      orig_inputs.shape.dims[-1])
  # Hack: we assume that
  #   "outer_batch" == replication of experts
  #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
  #
  # We then require num_groups to be a multiple of mesh_dim_size.
  if orig_inputs.shape.dims[0].name == "outer_batch":
    outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
  else:
    outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                       orig_inputs.shape.dims[0])

  # Number of MoE inputs (total number of positions across
  # batch_and_length_dims) per replica.
  n = 1
  for d in batch_and_length_dims:
    n *= d.size

  n = n // outer_batch_dim.size

  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                  orig_batch_dim)
  num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                              mesh_dim_size)
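  # E.g. (illustrative, matching the 16x16 example above): with
  # n = 4096 * 256 = 1,048,576 positions and moe_group_size = 1024, this
  # yields group_size = 1024 and num_groups = 1024 (num_groups is required to
  # be a multiple of mesh_dim_size).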

  group_size_dim = mtf.Dimension("group", group_size)
  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

  moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
  # OGSM Tensor
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Token embeddings that can be optionally used in the router for determining
  # where to send tokens.
  if hparams.moe_word_embed_mode is not None:
    token_embeddings = mtf.cast(
        mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

  # Each group sends at most expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
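  # E.g. (illustrative, matching the 16x16 example above): group_size = 1024,
  # capacity_factor = 2.0 and 512 experts give
  # expert_capacity = min(1024, int(1024 * 2.0 / 512)) = 4, assuming
  # hparams.moe_min_expert_capacity <= 4.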
  tf.logging.info("expert_capacity: %d" % expert_capacity)
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
  if nonpadding is not None:
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
  if hparams.moe_gating == "top_2":
    # combine_tensor,
    # dispatch_tensor  OG`SEC Tensors
    # (G is generally split along mesh dim)
    dispatch_tensor, combine_tensor, loss = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "top_n":
    dispatch_tensor, combine_tensor, loss = _top_n_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "switch":
    dispatch_tensor, combine_tensor, loss = _switch_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "ntlb":
    dispatch_tensor, combine_tensor, loss = _ntlb_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "switch_max":
    dispatch_tensor, combine_tensor, loss = _switch_max_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "expert_selection":
    dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        group_size_dim=group_size_dim,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        name="expert_selection_gating",
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                             mtf.Shape([
                                 outer_batch_dim, experts_dim_unsplit,
                                 num_groups_dim, expert_capacity_dim, input_dim
                             ]))
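  # The einsum above is the "gather-by-matmul" mentioned in the docstring:
  # position s of group g is copied into capacity slot c of expert e wherever
  # dispatch_tensor[o, g, s, e, c] is nonzero.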

  # Extra reshape reduces communication cost for model-parallel versions.
  # For model-parallel versions, this reshape causes an mtf.slice and for non-
  # model-parallel versions, this has no effect.
  d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          d_model_split_dim
      ]))

  # Split over batch -> split over experts
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          input_dim
      ]))

  # Now feed the expert inputs through the experts.
  h = mtf.layers.dense_product(
      expert_inputs,
      reduced_dims=expert_inputs.shape.dims[-1:],
      new_dims=[hidden_dim],
      expert_dims=[experts_dim],
      activation_functions=activation, use_bias=False,
      variable_dtype=variable_dtype, name="wi")

  if hparams.moe_dropout_rate != 0.0:
    h = mtf.dropout(h, is_training=train,
                    keep_prob=1.0 - hparams.moe_dropout_rate)

  def _compute_output(hidden, layer_name):
    """Compute the output of the attention layer from the hidden vector."""
    expert_output = mtf.layers.dense(
        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
        name=layer_name)

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    d_model_split_dim = mtf.Dimension(
        "d_model_split", expert_output.shape[-1].size)
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
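    # The einsum with combine_tensor (the gating weights) is the inverse of
    # the dispatch above: each expert/capacity slot is scattered back to its
    # originating (group, position) and weighted by its gate value.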
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output

  if hparams.moe_use_experts_attention:
    # We share the same hidden representation h between the q and k
    # projections, with no degradation in performance.
    q_h, k_h = h, h
    outputs = []
    q = _compute_output(q_h, layer_name="q_wo")
    k = _compute_output(k_h, layer_name="k_wo")
    outputs.append(q)
    outputs.append(k)
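    # In this mode the layer returns a list [q, k] instead of a single tensor.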
    return outputs, loss * hparams.moe_loss_coef
  else:
    output = _compute_output(h, layer_name="wo")
    return output, loss * hparams.moe_loss_coef