def transformer_moe_layer_v2()

in mesh_tensorflow/transformer/moe.py


def transformer_moe_layer_v2(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None, num_microbatches=None):
  """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: a list of two integers
      (number of first-level and second-level experts)
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_min_expert_capacity: an integer
    hparams.moe_loss_coef: a float (multiplier on the auxiliary loss)
    hparams.moe_gating: a string (only "top_2" is currently supported)
    + all hyperparameters used by _top_2_gating()

  The first-level gating network has a single set of parameters; each
  first-level expert has its own set of parameters for the second-level
  gating network.
  The number of parameters in the gating networks is:
    (input_dim.size * hparams.moe_num_experts[0]) +
    (input_dim.size * hparams.moe_num_experts[0] * hparams.moe_num_experts[1])


  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts[0] * hparams.moe_num_experts[1]
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)
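
  For example (illustrative sizes, not defaults): with input_dim = output_dim
  = 1024, moe_hidden_size = 4096 and moe_num_experts = [4, 4] (16 experts in
  total), the gating networks have 1024 * 4 + 1024 * 4 * 4 = 20,480 parameters
  and the experts have 16 * (1024 + 1024) * 4096 = 134,217,728 parameters.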

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.
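
    For example (illustrative sizes): with an inner batch of b = 8 sequences
    of length l = 1024 and moe_group_size = 1024, the b * l = 8192 positions
    are split into 8 groups of 1024 (assuming the mesh dimension over which
    groups are split divides 8), and each group sends the same number
    (expert_capacity) of positions to each expert.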

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a: outer batch size
  b: inner batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  reshape: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, t, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional mtf.Tensor with shape [a, b, l]
      and the same dtype as inputs, consisting of ones at non-padding positions
      and zeros at padding positions.
    num_microbatches: number of microbatches.

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  if nonpadding is not None:
    nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1],
                           dtype=inputs.dtype) + nonpadding
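  # If the input is 3-D ([batch, length, model_dim]), insert a size-1
  # "outer_batch" dimension so the rest of this function can assume 4-D
  # inputs; it is removed again before returning.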
  insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
  if insert_outer_batch_dim:
    inputs = mtf.reshape(
        inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)

  assert len(hparams.moe_num_experts) == 2
  a0, b1, l, m = inputs.shape.dims
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
  y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
  x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
  y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
  n = output_dim
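  # Naming convention (matching the cheat sheet in the docstring): a trailing
  # 0 or 1 on a variable marks the mesh dimension over which that tensor
  # dimension is expected to be split by the layout, while the "*_unsplit"
  # dimensions (x, y, and later g, h) are the unsplit counterparts used on the
  # other side of the all-to-all reshapes.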

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (g.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      b1.size * l.size, hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
  g1 = mtf.Dimension(b1.name, num_groups)
  g = mtf.Dimension(b1.name + "_unsplit", g1.size)
  s = mtf.Dimension("group_size_x", group_size)

  # Each group sends at most expert_capacity positions to each expert.
  # Static expert_capacity dimension is needed for expert batch sizes
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  c = mtf.Dimension("expert_capacity_x", expert_capacity)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (h.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      a0.size * g.size * c.size,
      hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
  t = mtf.Dimension("group_size_y", group_size)
  h0 = mtf.Dimension(a0.name, num_groups)
  h = mtf.Dimension(a0.name + "_unsplit", h0.size)

  expert_capacity = min(
      t.size,
      int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  d = mtf.Dimension("expert_capacity_y", expert_capacity)

  # First level of expert routing
  # Reshape the inner batch size to a multiple of group_dim g1 and
  # group_size_dim s.
  inputs = mtf.reshape(inputs, [a0, g1, s, m])
  if nonpadding is not None:
    nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

  # Get the assignments for the first level.
  # dispatch_tensor_x has shape [a0, g1, s, x, c]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=x,
        expert_capacity_dim=c,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        name="outer_gating",
        importance=nonpadding,
        num_microbatches=num_microbatches)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x], [x, a0, g1, c, m])

  # we construct an "importance" Tensor for the inputs to the second-level
  # gating.  The importance of an input is 1.0 if it represents the
  # first-choice expert-group and 0.5 if it represents the second-choice expert
  # group.  This is used by the second-level gating.
  importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
  importance = 0.5 * (
      mtf.to_float(mtf.greater(importance, 0.5)) +
      mtf.to_float(mtf.greater(importance, 0.0)))
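  # The thresholds above map slots whose combine weight exceeds 0.5
  # (first-choice assignments) to importance 1.0, slots with weight in
  # (0, 0.5] (second-choice assignments) to 0.5, and unused slots to 0.0.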

  # First level, all to all. Here we change the split dimension from g1 to x1.
  expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape(
      [x1, a0, g, c, m]))
  importance = mtf.reshape(importance, [x1, a0, g, c])

  # Second level of expert routing
  # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
  # and group_size_dim t.
  inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
  importance = mtf.reshape(importance, [x1, h0, t])

  # Get the assignments for the second level.
  # dispatch_tensor_y has shape [x1, h0, t, y, d]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
        inputs=inputs_y,
        outer_expert_dims=[x1],
        experts_dim=y,
        expert_capacity_dim=d,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=importance,
        name="inner_gating",
        num_microbatches=num_microbatches)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y], [y, x1, h0, d, m])

  # Second level, all to all. Here we change the split dimension from h0 to y0.
  expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
      [y0, x1, h, d, m]))

  hidden_output = mtf.layers.dense(
      expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
      reduced_dims=expert_inputs_y.shape.dims[-1:],
      activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype,
      name="wi")
  expert_output = mtf.layers.dense(
      hidden_output, output_dim, expert_dims=[y0, x1],
      reduced_dims=hidden_output.shape.dims[-1:],
      use_bias=False, variable_dtype=variable_dtype,
      name="wo")

  # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
  # expert_output has shape [y0, x1, h, d, n]

  # alltoall
  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [y, x1, h0, d, n]))

  # combine results from inner level
  output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

  # Reshape the combined tensor from inner level to now contain outer_batch_dim
  # a0 and group_dim g
  output = mtf.reshape(output_y, [x1, a0, g, c, n])

  # alltoall from expert_dim x to group_dim g1
  expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

  # combine results from outer level
  output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

  # Reshape the combined tensor to now contain inner_batch_dim
  # b1 and the original sequence length
  output = mtf.reshape(output_x, [a0, b1, l, n])
  if insert_outer_batch_dim:
    output = mtf.reshape(output, [b1, l, n])
  return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
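
Example usage (a minimal sketch, not taken from the library: the dimension
sizes, hparams values, mesh shape and layout below are illustrative
assumptions, and hparams must additionally supply every hyperparameter read by
_top_2_gating()):

  import types

  import mesh_tensorflow as mtf
  import tensorflow.compat.v1 as tf

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  batch = mtf.Dimension("batch", 8)
  length = mtf.Dimension("length", 1024)
  d_model = mtf.Dimension("d_model", 512)

  # 3-D input [batch, length, d_model]; the layer inserts (and later removes)
  # the outer_batch dimension itself.
  inputs = mtf.zeros(mesh, mtf.Shape([batch, length, d_model]))

  # The function only reads attributes from hparams, so any attribute
  # container with the fields listed in the docstring works; the values here
  # are made up.
  hparams = types.SimpleNamespace(
      moe_num_experts=[4, 4],
      moe_hidden_size=2048,
      moe_group_size=1024,
      moe_capacity_factor_train=1.25,
      moe_capacity_factor_eval=2.0,
      moe_capacity_factor_second_level=1.0,
      moe_min_expert_capacity=4,
      moe_loss_coef=1e-2,
      moe_gating="top_2",
      # ... plus the hyperparameters used by _top_2_gating() ...
  )

  # A hypothetical 2x4 mesh: inner batch and first-level experts split over
  # mesh dimension "d1", second-level experts over "d0".
  mesh_shape = mtf.Shape([mtf.Dimension("d0", 2), mtf.Dimension("d1", 4)])
  layout = [("batch", "d1"), ("expert_x", "d1"), ("expert_y", "d0")]

  outputs, aux_loss = transformer_moe_layer_v2(
      inputs=inputs,
      output_dim=d_model,
      hparams=hparams,
      train=True,
      variable_dtype=mtf.VariableDType(tf.float32, tf.float32, tf.float32),
      layout=layout,
      mesh_shape=mesh_shape)
  # aux_loss is an mtf scalar; add it to the model's training loss.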