mesh_tensorflow/transformer/funnel_transformer.py [253:312]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  def call_simple(self,
                  inputs,
                  targets,
                  compute_loss,
                  mode=tf.estimator.ModeKeys.TRAIN,
                  variable_dtype=mtf.VariableDType(tf.float32),
                  encoder_sequence_id=None,
                  decoder_sequence_id=None,
                  decoder_subsequence_id=None,
                  encoder_position=None,
                  decoder_position=None,
                  num_microbatches=1):
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    This class inherits the trnasformer.Bitransformer with one difference. The
    encoder is Funnel Transformer. So the length dimension is reduced. The
    decoder needs to use the updated `encoder_sequence_id`.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor
      decoder_subsequence_id: an optional Tensor
      encoder_position: an optional Tensor
      decoder_position: an optional Tensor
      num_microbatches: integer - greater than one if the step has been
        serialized into multiple microbatches to save memory.

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
    # encoder_sequene_id and decoder_sequence_id are used to delineate packed
    # examples but are also necessary to indicate padding where sequence_id==0.
    # If they are absent, then we assume that padding is indicated by zeros in
    # the inputs/targets, and we make up sequence_id tensors to indicate this.
    if encoder_sequence_id is None:
      encoder_sequence_id = mtf.minimum(inputs, 1)
    if decoder_sequence_id is None:
      decoder_sequence_id = mtf.minimum(targets, 1)
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        position=encoder_position,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs,
        num_microbatches=num_microbatches)
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



mesh_tensorflow/transformer/transformer.py [1525:1580]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  def call_simple(self,
                  inputs,
                  targets,
                  compute_loss,
                  mode=tf.estimator.ModeKeys.TRAIN,
                  variable_dtype=mtf.VariableDType(tf.float32),
                  encoder_sequence_id=None,
                  decoder_sequence_id=None,
                  decoder_subsequence_id=None,
                  encoder_position=None,
                  decoder_position=None,
                  num_microbatches=1):
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor
      decoder_subsequence_id: an optional Tensor
      encoder_position: an optional Tensor
      decoder_position: an optional Tensor
      num_microbatches: integer - greater than one if the step has been
        serialized into multiple microbatches to save memory.

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
    # encoder_sequene_id and decoder_sequence_id are used to delineate packed
    # examples but are also necessary to indicate padding where sequence_id==0.
    # If they are absent, then we assume that padding is indicated by zeros in
    # the inputs/targets, and we make up sequence_id tensors to indicate this.
    if encoder_sequence_id is None:
      encoder_sequence_id = mtf.minimum(inputs, 1)
    if decoder_sequence_id is None:
      decoder_sequence_id = mtf.minimum(targets, 1)
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        position=encoder_position,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs,
        num_microbatches=num_microbatches)
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



