mesh_tensorflow/transformer/heterogeneous_moe.py [100:130]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    self._activation = activation

  def call(self, context, x, losses=None):
    """Call the layer."""
    if context.model.ensemble_dim:
      raise NotImplementedError("MoE not yet implemented with ensembles")

    has_length_dim = context.length_dim in x.shape.dims
    if not has_length_dim:
      x_shape = x.shape
      shape_with_length = mtf.Shape(
          x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
          + x_shape.dims[-1:])
      x = mtf.reshape(x, shape_with_length)

    # Output dim: hparams.moe_output_dim if set, else the model's model_dim.
    if self._hparams.moe_output_dim is not None:
      output_dim = self._hparams.moe_output_dim
    else:
      output_dim = context.model.model_dim
    y, loss = transformer_moe_layer_v1(
        x,
        output_dim,
        self._hparams,
        context.train,
        context.variable_dtype,
        layout=context.model.layout,
        mesh_shape=context.model.mesh_shape,
        nonpadding=context.nonpadding,
        activation=self._activation,
        num_microbatches=context.num_microbatches,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



mesh_tensorflow/transformer/moe.py [99:129]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    self._activation = activation

  def call(self, context, x, losses=None):
    """Call the layer."""
    if context.model.ensemble_dim:
      raise NotImplementedError("MoE not yet implemented with ensembles")

    has_length_dim = context.length_dim in x.shape.dims
    if not has_length_dim:
      x_shape = x.shape
      shape_with_length = mtf.Shape(
          x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
          + x_shape.dims[-1:])
      x = mtf.reshape(x, shape_with_length)

    # Output dim: hparams.moe_output_dim if set, else the model's model_dim.
    if self._hparams.moe_output_dim is not None:
      output_dim = self._hparams.moe_output_dim
    else:
      output_dim = context.model.model_dim
    y, loss = transformer_moe_layer_v1(
        x,
        output_dim,
        self._hparams,
        context.train,
        context.variable_dtype,
        layout=context.model.layout,
        mesh_shape=context.model.mesh_shape,
        nonpadding=context.nonpadding,
        activation=self._activation,
        num_microbatches=context.num_microbatches,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



