in tensor2tensor/models/evolved_transformer.py [0:0]
def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
"""Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.
Args:
decoder_input: a Tensor.
encoder_output: a Tensor.
decoder_self_attention_bias: bias Tensor for self-attention (see
common_attention.attention_bias()).
encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
(see common_attention.attention_bias()).
hparams: hyperparameters for model.
cache: dict, containing tensors which are the results of previous
layers, used for fast decoding.
decode_loop_step: An integer, step number of the decoding loop. Only used
for inference on TPU.
name: a string.
nonpadding: optional Tensor with shape [batch_size, encoder_length]
indicating what positions are not padding. This is used to mask out
padding in convolutional layers. We generally only need this mask for
"packed" datasets, because for ordinary datasets, no padding is ever
followed by nonpadding.
save_weights_to: an optional dictionary to capture attention weights for
visualization; the weights tensor will be appended there under a string
key created from the variable scope (including name).
make_image_summary: Whether to make an attention image summary.
losses: Not supported.
Returns:
Decoder output tensor.
"""
  del losses

  num_trainable_top_decoder_layers = hparams.get(
      "num_trainable_top_decoder_layers", -1)  # -1 means train all weights.
  if num_trainable_top_decoder_layers >= 0:
    encoder_output = tf.stop_gradient(encoder_output)

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
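  # Each decoder cell applies, in order: a wide (up to 16-head) self-attention
  # branch summed with a first attend-to-encoder branch, two separable
  # convolution branches, a vanilla self-attention layer, a second
  # attend-to-encoder layer, and a dense feed-forward block with swish.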
  with tf.variable_scope(name):
    hidden_state = decoder_input

    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    for layer in range(num_layers):
      if num_trainable_top_decoder_layers == num_layers - layer:
        hidden_state = tf.stop_gradient(hidden_state)
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
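        # Self-attention branch with up to double the usual number of heads
        # (capped at 16 by _capped_double_heads).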
        with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None
          left_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              _capped_double_heads(hparams.num_heads),
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
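        # First attend-to-encoder branch; its output is added to the
        # self-attention branch and the residual. Without an encoder
        # (decoder-only use), fall back to the standard residual combination.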
        if encoder_output is not None:
          with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME):
            attention_cache = (
                layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            right_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))

            left_state = tf.nn.dropout(
                left_state, 1 - hparams.layer_prepostprocess_dropout)
            right_state = tf.nn.dropout(
                right_state, 1 - hparams.layer_prepostprocess_dropout)

            hidden_state = residual_state + left_state + right_state

        else:
          hidden_state = common_layers.layer_postprocess(
              residual_state, left_state, hparams)
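        # Convolution block: a wide 11x1 separable conv (ReLU, 2 * hidden_size
        # output channels) and a 7x1 separable conv (hidden_size / 2 channels,
        # zero-padded on the channel axis) are summed, then passed through a
        # final 7x1 separable conv projecting back to hidden_size.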
        with tf.variable_scope(_CONV_BRANCHES_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
            hidden_state *= mask
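          # During incremental decoding the causal conv inputs are cached:
          # on CPU/GPU the cache is shifted with tf.concat; on TPU it is
          # updated in place at the position given by decode_loop_step.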
          if layer_cache:
            if decode_loop_step is None:
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :]
              left_state = hidden_state
              right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING -
                                         _DECODER_RIGHT_CONV_PADDING:, :]
            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp,
                  decode_loop_step * tf.shape(hidden_state)[1] +
                  _DECODER_LEFT_CONV_PADDING,
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              batch_size = hidden_state.shape.as_list()[0]
              left_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_LEFT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])
              right_state = tf.slice(hidden_state, [
                  0, decode_loop_step + _DECODER_LEFT_CONV_PADDING -
                  _DECODER_RIGHT_CONV_PADDING, 0
              ], [
                  batch_size, _DECODER_RIGHT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])

          else:  # No caching.
            left_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]])
            right_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]])
          left_output_dim = int(hparams.hidden_size * 2)
          separable_conv_11x1 = tf.layers.SeparableConv1D(
              left_output_dim,
              11,
              padding="VALID",
              name="separable_conv11x1",
              activation=tf.nn.relu)
          left_state = separable_conv_11x1.apply(left_state)
          left_state = tf.nn.dropout(
              left_state, 1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          separable_conv_7x1_1 = tf.layers.SeparableConv1D(
              right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1")
          right_state = separable_conv_7x1_1.apply(right_state)
          right_state = tf.nn.dropout(
              right_state, 1 - hparams.layer_prepostprocess_dropout)
          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)

          hidden_state = left_state + right_state
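          # Second conv layer: normalize again, re-mask padding (the channel
          # dimension is now 2 * hidden_size), and apply the final 7x1
          # separable conv, caching its input the same way as above.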
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :]
            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) *
                  tf.shape(hidden_state)[1],
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              batch_size = hidden_state.shape.as_list()[0]
              hidden_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_FINAL_CONV_PADDING + 1,
                  hparams.hidden_size * 2
              ])
          else:
            hidden_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]])

          separable_conv_7x1_2 = tf.layers.SeparableConv1D(
              hparams.hidden_size,
              7,
              padding="VALID",
              name="separable_conv_7x1_2")
          hidden_state = separable_conv_7x1_2.apply(hidden_state)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)
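        # Standard multi-head self-attention with hparams.num_heads heads.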
        with tf.variable_scope(_VANILLA_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _VANILLA_ATTENTION_NAME] if layer_cache is not None else None
          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)
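        # Second attend-to-encoder layer, applied only when an encoder output
        # is available.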
        if encoder_output is not None:
          with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME):
            residual_state = hidden_state
            hidden_state = common_layers.layer_preprocess(
                hidden_state, hparams)

            attention_cache = (
                layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            hidden_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))
            hidden_state = common_layers.layer_postprocess(
                residual_state, hidden_state, hparams)
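        # Feed-forward block: project to 4 * hidden_size with swish, apply
        # dropout and another layer_preprocess, then project back to
        # hidden_size.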
with tf.variable_scope("dense_layers"):
residual_state = hidden_state
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
hidden_state = tf.layers.dense(
hidden_state,
int(hparams.hidden_size * 4),
activation=tf.nn.swish)
hidden_state = tf.nn.dropout(hidden_state,
1 - hparams.layer_prepostprocess_dropout)
hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
hidden_state = common_layers.layer_postprocess(
residual_state, hidden_state, hparams)
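    # Final layer_preprocess; when no decoder layers are trainable, block all
    # gradients through the decoder output.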
    decoder_output = common_layers.layer_preprocess(hidden_state, hparams)
    if num_trainable_top_decoder_layers == 0:
      decoder_output = tf.stop_gradient(decoder_output)
    return decoder_output