def evolved_transformer_decoder()

in tensor2tensor/models/evolved_transformer.py [0:0]


def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
  """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.

  Each decoder layer is built from the following sub-blocks, in order:
    1. A self-attention branch (with `_capped_double_heads(hparams.num_heads)`
       heads) and, when `encoder_output` is given, a parallel
       encoder-decoder attention branch; both are added to the residual.
    2. Two parallel separable convolutions (kernel sizes 11 and 7) whose
       outputs are summed, followed by a second separable convolution
       (kernel size 7).
    3. A self-attention block.
    4. An encoder-decoder attention block (only when `encoder_output` is
       given).
    5. Two dense layers with a swish activation in between.

  Args:
    decoder_input: a Tensor.
    encoder_output: a Tensor, or None to skip both encoder-decoder attention
      blocks.
    decoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for model. The optional key
      "num_trainable_top_decoder_layers" (default -1, meaning train all
      weights) stops gradients from flowing into everything below the top N
      decoder layers, including the encoder output.
    cache: dict, containing tensors which are the results of previous
      layers, used for fast decoding. It is indexed by "layer_%d" and, inside
      each layer, by the sub-block scope names; the convolution entries are
      updated in place by this function.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, decoder_length]
      indicating what positions are not padding.  It is multiplied into the
      decoder hidden state, so it must match the first two dimensions of
      `decoder_input`.  This is used to mask out padding in convolutional
      layers.  We generally only need this mask for "packed" datasets,
      because for ordinary datasets, no padding is ever followed by
      nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not supported.

  Returns:
    Decoder output tensor.
  """
  del losses

  num_trainable_top_decoder_layers = hparams.get(
      "num_trainable_top_decoder_layers", -1)  # -1 means train all weights.

  # When any decoder layers are frozen the encoder is frozen too. The encoder
  # output is optional, so only stop its gradient when it is actually present
  # (tf.stop_gradient cannot convert None to a tensor).
  if num_trainable_top_decoder_layers >= 0 and encoder_output is not None:
    encoder_output = tf.stop_gradient(encoder_output)

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name):
    hidden_state = decoder_input

    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    for layer in range(num_layers):
      # Cut the gradient at the input of the lowest trainable layer so that
      # every layer below it stays frozen.
      if num_trainable_top_decoder_layers == num_layers - layer:
        hidden_state = tf.stop_gradient(hidden_state)
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):

        # Self-attention branch ("left") of the first sub-block.
        with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None
          left_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              _capped_double_heads(hparams.num_heads),
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))

        if encoder_output is not None:
          # Encoder-decoder attention branch ("right"), fed with the same
          # preprocessed input as the self-attention branch above.
          with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME):
            attention_cache = (
                layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            right_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))

            # Dropout is applied to each branch separately before both are
            # added to the residual.
            left_state = tf.nn.dropout(left_state,
                                       1 - hparams.layer_prepostprocess_dropout)
            right_state = tf.nn.dropout(
                right_state, 1 - hparams.layer_prepostprocess_dropout)

            hidden_state = residual_state + left_state + right_state

        else:
          # Without an encoder only the self-attention branch exists; combine
          # it with the residual through the standard postprocess step.
          hidden_state = common_layers.layer_postprocess(
              residual_state, left_state, hparams)

        with tf.variable_scope(_CONV_BRANCHES_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              # Append the new positions to the cached inputs and keep only the
              # last _DECODER_LEFT_CONV_PADDING + 1 positions; the result is
              # written back to the cache.
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :]
              left_state = hidden_state
              # The right branch uses a shorter window: drop the leading
              # (_DECODER_LEFT_CONV_PADDING - _DECODER_RIGHT_CONV_PADDING)
              # positions.
              right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING -
                                         _DECODER_RIGHT_CONV_PADDING:, :]

            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp,
                  decode_loop_step * tf.shape(hidden_state)[1] +
                  _DECODER_LEFT_CONV_PADDING,
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              # Slice out the windows ending at the position just written.
              # NOTE(review): batch_size comes from the static shape, so this
              # path presumably requires a statically known batch dimension --
              # confirm against the TPU decoding caller.
              batch_size = hidden_state.shape.as_list()[0]
              left_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_LEFT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])
              right_state = tf.slice(hidden_state, [
                  0, decode_loop_step + _DECODER_LEFT_CONV_PADDING -
                  _DECODER_RIGHT_CONV_PADDING, 0
              ], [
                  batch_size, _DECODER_RIGHT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])

          else:  # No caching.
            # Pad only on the left of the length axis, so the "VALID"
            # convolutions below never see positions to the right of the
            # current one.
            left_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]])
            right_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]])

          left_output_dim = int(hparams.hidden_size * 2)
          separable_conv_11x1 = tf.layers.SeparableConv1D(
              left_output_dim,
              11,
              padding="VALID",
              name="separable_conv11x1",
              activation=tf.nn.relu)
          left_state = separable_conv_11x1.apply(left_state)
          left_state = tf.nn.dropout(left_state,
                                     1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          separable_conv_7x1_1 = tf.layers.SeparableConv1D(
              right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1")
          right_state = separable_conv_7x1_1.apply(right_state)
          right_state = tf.nn.dropout(right_state,
                                      1 - hparams.layer_prepostprocess_dropout)
          # Zero-pad the channel axis of the narrower right branch up to the
          # left branch's width so the two can be summed.
          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)

          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              # Same rolling-window cache update as for the first conv layer,
              # with a window of _DECODER_FINAL_CONV_PADDING + 1 positions.
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :]

            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) *
                  tf.shape(hidden_state)[1],
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              batch_size = hidden_state.shape.as_list()[0]
              hidden_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_FINAL_CONV_PADDING + 1,
                  hparams.hidden_size * 2
              ])
          else:
            # No caching: left-pad the length axis before the "VALID" conv.
            hidden_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]])

          separable_conv_7x1_2 = tf.layers.SeparableConv1D(
              hparams.hidden_size,
              7,
              padding="VALID",
              name="separable_conv_7x1_2")
          hidden_state = separable_conv_7x1_2.apply(hidden_state)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        # Second self-attention block, using hparams.num_heads heads.
        with tf.variable_scope(_VANILLA_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _VANILLA_ATTENTION_NAME] if layer_cache is not None else None
          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        if encoder_output is not None:
          # Second encoder-decoder attention block.
          with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME):
            residual_state = hidden_state
            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

            attention_cache = (
                layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            hidden_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))
            hidden_state = common_layers.layer_postprocess(
                residual_state, hidden_state, hparams)

        # Feed-forward block: expand to 4 * hidden_size with swish, then
        # project back to hidden_size.
        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(
              hidden_state,
              int(hparams.hidden_size * 4),
              activation=tf.nn.swish)
          hidden_state = tf.nn.dropout(hidden_state,
                                       1 - hparams.layer_prepostprocess_dropout)

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    decoder_output = common_layers.layer_preprocess(hidden_state, hparams)
    # With zero trainable top layers the whole decoder stack is frozen, so no
    # gradient may flow back through its output either.
    if num_trainable_top_decoder_layers == 0:
      decoder_output = tf.stop_gradient(decoder_output)
    return decoder_output