in tensor2tensor/models/mtf_transformer.py [0:0]
def _layer_stack(self,
x,
layers,
encoder_output=None,
self_attention_mask=None,
encdec_attention_mask=None,
losses=None,
step_num=None,
encdec_tensors=None,
states=None):
"""Encoder or decoder stack.
Args:
x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
layers: a list of strings naming the layer types, applied in order
encoder_output: an optional mtf.Tensor with shape
[<batch_dims>, encoder_length_dim, model_dim]
self_attention_mask: an optional mtf.Tensor with shape
[batch, length_dim, memory_length_dim] containing values 0 or -inf.
encdec_attention_mask: an optional mtf.Tensor with shape
[batch, length_dim, encoder_length_dim] containing values 0 or -inf.
losses: an optional list to which extra losses are appended
step_num: an optional mtf integer Scalar (used in incremental mode)
encdec_tensors: an optional list of num_layers tuples, each of the form
(q_var, o_var, k, v) (used in incremental mode)
states: an optional list of Tensors (used in incremental mode)

Returns:
a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]

Raises:
ValueError: if hparams make no sense
"""
hparams = self._hparams
is_incremental = (step_num is not None)
mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
is_training = mode == tf.estimator.ModeKeys.TRAIN
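# Dropout on each residual branch before it is added back into x. The
# noise shape omits length_dim, so one dropout mask is shared across
# positions; in incremental (one-step) decoding this is a no-op.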
def layer_prepostprocess_dropout(x):
if is_incremental:
return x
return mtf.dropout(
x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
num_layers = len(layers)
num_layer_norms = num_layers + 1
layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
layer_norm_combined_var = mtf.get_variable(
x.mesh,
"layer_norm_scale",
mtf.Shape([layer_norms_dim, self.model_dim]),
initializer=tf.ones_initializer(),
activation_dtype=x.dtype)
layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
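# One normalization scale per sublayer plus one for the final output,
# stored as a single stacked variable and consumed in order by normalize()
# below, which is an RMS-style layer norm (scale only, no mean subtraction
# or bias).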
def normalize(x):
scale = layer_norm_vars.pop(0)
variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale
if is_incremental:
states = list(states)
new_states = []
tf.logging.info("states = %s" % (states,))
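# In incremental mode, each self-attention and local-attention layer pops
# its previous (k, v) cache from `states` and appends the updated cache to
# `new_states`, so the states must be fed back in the same layer order on
# the next decoding step.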
for lnum, layer_type in enumerate(layers):
with tf.variable_scope("%s_%d" % (layer_type, lnum)):
if layer_type == "att":
# Self-attention layer (uses cached k/v states in incremental mode)
if is_incremental:
y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
normalize(x),
prev_k=states.pop(0),
prev_v=states.pop(0),
step_num=step_num,
master_dtype=self.master_dtype,
slice_dtype=self.slice_dtype,
name="att")
new_states.append(new_k)
new_states.append(new_v)
x += y
else:
x += layer_prepostprocess_dropout(
mtf.layers.multihead_attention(
normalize(x), None,
self_attention_mask, self.kv_dim, self.heads_dim,
is_training,
dropout=hparams.attention_dropout,
dropout_broadcast_dims=[self.length_dim],
master_dtype=self.master_dtype,
slice_dtype=self.slice_dtype,
name="att"))
elif layer_type == "enc_att":
# Encoder-Decoder attention layer
if is_incremental:
# Incremental path: the encoder keys/values and the projection variables
# were precomputed once and are passed in via encdec_tensors.
q_var, o_var, k, v = encdec_tensors[lnum]
x += mtf.layers.multihead_encdec_attention_incremental(
normalize(x),
q_var, o_var, k, v,
encdec_attention_mask,
name="enc_att")
else:
x += layer_prepostprocess_dropout(
mtf.layers.multihead_attention(
normalize(x), encoder_output,
encdec_attention_mask, self.kv_dim, self.heads_dim,
is_training,
dropout=hparams.attention_dropout,
dropout_broadcast_dims=[self.length_dim],
master_dtype=self.master_dtype,
slice_dtype=self.slice_dtype,
name="enc_att"))
elif layer_type == "local_att":
if is_incremental:
y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
normalize(x),
prev_k=states.pop(0),
prev_v=states.pop(0),
step_num=step_num,
master_dtype=self.master_dtype,
slice_dtype=self.slice_dtype,
name="local_att")
new_states.append(new_k)
new_states.append(new_v)
x += y
else:
x += layer_prepostprocess_dropout(
mtf.layers.masked_local_attention_1d(
normalize(x),
self.kv_dim, self.heads_dim, is_training,
window_size=hparams.local_attention_window_size,
master_dtype=self.master_dtype,
slice_dtype=self.slice_dtype,
length_per_split=mtf.tensor_dim_to_size_per_split(
hparams.layout, hparams.mesh_shape,
self.max_length_dim),
name="local_att"))
elif layer_type == "compressed_att":
if is_incremental:
raise ValueError("compressed_att incremental not implemented")
else:
x += layer_prepostprocess_dropout(
mtf.layers.multihead_self_attention_memory_compressed(
normalize(x),
mask_right=True,
compression_factor=hparams.compression_factor,
kv_channels=self.kv_dim,
heads=self.heads_dim,
is_training=is_training,
dropout=hparams.attention_dropout,
dropout_broadcast_dims=[self.length_dim],
master_dtype=self.master_dtype,
slice_dtype=self.slice_dtype,
name="compressed_att"))
else:
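# Any other layer type is treated as a feed-forward layer.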
if is_incremental:
# insert length dimension.
x_shape = x.shape
shape_with_length = mtf.Shape(
x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
+ x_shape.dims[-1:])
x = mtf.reshape(x, shape_with_length)
# ffn layer
x += layer_prepostprocess_dropout(
self._feedforward_layer(normalize(x), layer_type, losses=losses))
if is_incremental:
# remove length dimension
x = mtf.reshape(x, x_shape)
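# The final normalization consumes the last of the num_layers + 1 scales;
# the assert below checks that every scale was used exactly once.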
x = layer_prepostprocess_dropout(normalize(x))
assert not layer_norm_vars
if is_incremental:
return x, new_states
else:
return x
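# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file). Each pass through the
# layer loop above is a pre-norm residual block,
#   x += dropout(sublayer(rms_norm(x))),
# followed by one extra normalization of the final output. A minimal NumPy
# rendering of that pattern (dropout omitted, all names made up):
import numpy as np


def _rms_norm(x, scale, eps=1e-6):
  # Scale-only normalization over the last axis: no mean subtraction, no bias.
  variance = np.mean(np.square(x), axis=-1, keepdims=True)
  return x / np.sqrt(variance + eps) * scale


def _toy_layer_stack(x, sublayers, scales):
  # `scales` holds len(sublayers) + 1 vectors: one per sublayer plus one for
  # the final output normalization, mirroring num_layer_norms above.
  for sublayer, scale in zip(sublayers, scales[:-1]):
    x = x + sublayer(_rms_norm(x, scale))
  return _rms_norm(x, scales[-1])


# Example: two elementwise "sublayers" on a [batch, length, model_dim] input.
_x = np.random.randn(2, 5, 8).astype(np.float32)
_scales = [np.ones(8, dtype=np.float32) for _ in range(3)]
_out = _toy_layer_stack(_x, [np.tanh, np.tanh], _scales)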