def _layer_stack()

in tensor2tensor/models/mtf_transformer.py [0:0]


  def _layer_stack(self,
                   x,
                   layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None,
                   step_num=None,
                   encdec_tensors=None,
                   states=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: a list of strings, one per layer, naming the layer type
        ("att", "enc_att", "local_att", "compressed_att", or a feed-forward
        layer type handled by _feedforward_layer)
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: an optional list to which extra losses (e.g. from the
        feed-forward layers) are appended
      step_num: an optional mtf integer Scalar (used in incremental mode)
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v) (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode)
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]; in
        incremental mode, an (x, new_states) tuple where new_states is the
        updated list of attention-state Tensors
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    is_incremental = (step_num is not None)
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
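    # Dropout on the residual path, skipped in one-step (incremental)
    # decoding; noise_shape broadcasts the mask across the length dimension.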
    def layer_prepostprocess_dropout(x):
      if is_incremental:
        return x
      return mtf.dropout(
          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
    num_layers = len(layers)
    num_layer_norms = num_layers + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
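    # All layer-norm scales are allocated as one [layer_norms, model_dim]
    # variable (one scale per layer plus one for the final normalization)
    # and then unstacked.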
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
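    # Scale-only (RMS-style) normalization: divide by the root mean square
    # over model_dim and rescale; each call consumes the next scale in order.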
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

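    # In incremental decoding, `states` carries the cached keys/values for
    # each attention layer; updated caches are collected in `new_states`.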
    if is_incremental:
      states = list(states)
      new_states = []
    tf.logging.info("states = %s" % (states,))

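    # Walk the layer spec; lnum also indexes encdec_tensors for incremental
    # encoder-decoder attention.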
    for lnum, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, lnum)):
        if layer_type == "att":
          # Self attention layer
          if is_incremental:
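            # Single-step self-attention against the cached keys/values;
            # the updated cache is stored for the next step.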
            y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), None,
                    self_attention_mask, self.kv_dim, self.heads_dim,
                    is_training,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="att"))
        elif layer_type == "enc_att":
          # Encoder-Decoder attention layer
          if is_incremental:
            # Reuse the precomputed (q_var, o_var, k, v) for this layer from
            # encdec_tensors; encoder keys/values are fixed during decoding.
            q_var, o_var, k, v = encdec_tensors[lnum]
            x += mtf.layers.multihead_encdec_attention_incremental(
                normalize(x),
                q_var, o_var, k, v,
                encdec_attention_mask,
                name="enc_att")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), encoder_output,
                    encdec_attention_mask, self.kv_dim, self.heads_dim,
                    is_training,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="enc_att"))
        elif layer_type == "local_att":
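          # Masked local (windowed) self-attention.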
          if is_incremental:
            y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="local_att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.masked_local_attention_1d(
                    normalize(x),
                    self.kv_dim, self.heads_dim, is_training,
                    window_size=hparams.local_attention_window_size,
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    length_per_split=mtf.tensor_dim_to_size_per_split(
                        hparams.layout, hparams.mesh_shape,
                        self.max_length_dim),
                    name="local_att"))
        elif layer_type == "compressed_att":
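          # Self-attention over memory-compressed keys/values
          # (strided by hparams.compression_factor).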
          if is_incremental:
            raise ValueError("compressed_att incremental not implemented")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_self_attention_memory_compressed(
                    normalize(x),
                    mask_right=True,
                    compression_factor=hparams.compression_factor,
                    kv_channels=self.kv_dim,
                    heads=self.heads_dim,
                    is_training=is_training,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="compressed_att"))
        else:
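          # Any other layer type is treated as a feed-forward block and
          # dispatched to self._feedforward_layer.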
          if is_incremental:
            # insert length dimension.
            x_shape = x.shape
            shape_with_length = mtf.Shape(
                x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
                + x_shape.dims[-1:])
            x = mtf.reshape(x, shape_with_length)
          # ffn layer
          x += layer_prepostprocess_dropout(
              self._feedforward_layer(normalize(x), layer_type, losses=losses))
          if is_incremental:
            # remove length dimension
            x = mtf.reshape(x, x_shape)

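    # Final normalization (the one extra scale) and output dropout; every
    # allocated scale must have been consumed by now.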
    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    if is_incremental:
      return x, new_states
    else:
      return x
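
A minimal sketch of how this stack might be invoked from inside the model for a
decoder in non-incremental (training) mode. The layer-type strings and the
surrounding tensor names (decoder_input, the masks, encoder_output) are
illustrative assumptions, not taken from this file; only the call signature
comes from the code above.

  # Hypothetical layer spec: self-attention, encoder-decoder attention, and a
  # feed-forward block ("drd" is an assumed name routed to
  # self._feedforward_layer), repeated twice.
  layers = ["att", "enc_att", "drd"] * 2
  losses = []
  decoder_output = self._layer_stack(
      decoder_input,                    # [<batch_dims>, length_dim, model_dim]
      layers,
      encoder_output=encoder_output,    # attended to by the "enc_att" layers
      self_attention_mask=decoder_self_attention_mask,  # values 0 or -inf
      encdec_attention_mask=encdec_attention_mask,      # values 0 or -inf
      losses=losses)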