in tensor2tensor/models/research/aligned.py [0:0]
def body_sharded(self, sharded_features):
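"""Applies the sublayer sequence named in hparams.layers to the sharded inputs.

Returns a (decoder_output, extra_loss) pair, where extra_loss accumulates
the auxiliary losses emitted by the grouped / expert / LSH attention sublayers.
"""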
# Note: dropout rates in hparams are zeroed out upstream when not training,
# so the dropout calls below are no-ops outside TRAIN mode.
hparams = self._hparams
dp = self._data_parallelism
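# dp maps a function across the data-parallel shards. Embedded inputs arrive
# as [batch, length, 1, hidden_size]; squeeze out the dummy axis.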
x = dp(tf.squeeze, sharded_features["inputs"], 2)
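# Residual / normalization wrappers, controlled by
# hparams.layer_preprocess_sequence and hparams.layer_postprocess_sequence.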
def preprocess(x):
return dp(common_layers.layer_preprocess, x, hparams)
def postprocess(x, y):
return dp(common_layers.layer_postprocess, x, y, hparams)
x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
extra_loss = 0.0
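# Hidden layer sizes for the "ffn" sublayer, given as a comma-separated string.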
ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
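# Self-attention bias: a causal (lower-triangular) mask per shard when
# mask_right is set, otherwise a single zero bias that dp broadcasts to
# every shard.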
if hparams.mask_right:
def _bias(x):
return common_attention.attention_bias_lower_triangle(
common_layers.shape_list(x)[1])
bias = dp(_bias, x)
else:
bias = tf.zeros([1, 1, 1, 1])
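# Per-position batch indices, used by the expert and LSH attention sublayers
# to keep attention within the same example.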
batch_coordinate = dp(get_batch_coordinate, x)
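# hparams.layers names the sublayer sequence as a comma-separated string,
# e.g. "timing,att,ffn,att,ffn".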
layers = hparams.layers.strip(",").split(",")
for layer_num, layer_type in enumerate(layers):
with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
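# Sublayers that already add to x directly (e.g. timing and pos_emb) skip
# the residual/normalization wrappers via _should_preprocess /
# _should_postprocess.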
if _should_preprocess(layer_type):
x = preprocess(x)
if layer_type == "timing":
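# Add a fixed sinusoidal timing signal (no learned parameters).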
y = dp(common_attention.add_timing_signal_nd, x)
elif layer_type == "pos_emb":
y = dp(
common_attention.add_positional_embedding_nd,
x,
hparams.max_length,
name="pos_emb")
elif layer_type == "att":
y = dp(
common_attention.multihead_attention,
x,
None,
bias, # bias
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout)
elif layer_type == "att_grouped":
multiplicative_overhead = (
hparams.multiplicative_overhead if hparams.mode == ModeKeys.TRAIN
else hparams.multiplicative_overhead_eval)
y, loss = dp(
common_attention.grouped_attention_multihead,
x,
x,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
num_groups=hparams.attention_num_groups,
memory_target_density=hparams.memory_target_density,
multiplicative_overhead=multiplicative_overhead,
make_image_summary=hparams.attention_image_summary,
mask_right=hparams.mask_right,
)
extra_loss += tf.add_n(loss) / dp.n
elif layer_type == "att_memory_efficient":
assert hparams.layer_preprocess_sequence == "n"
y = dp(common_attention.multihead_self_attention_memory_efficient, x,
bias, hparams.num_heads)
elif layer_type == "att_local":
y = dp(
common_attention.multihead_attention,
x,
None,
None, # bias
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
attention_type=("local_mask_right"
if hparams.mask_right else "local_unmasked"),
block_length=hparams.local_attention_window,
block_width=hparams.local_attention_window)
elif layer_type == "att_pseudolocal":
# Inefficient reference implementation of local attention, used to check
# model quality: full attention with a banded bias that masks positions
# outside the local window.
def _pseudolocal_bias(x):
return common_attention.attention_bias_local(
common_layers.shape_list(x)[1], hparams.local_attention_window,
0 if hparams.mask_right else hparams.local_attention_window)
pseudolocal_bias = dp(_pseudolocal_bias, x)
y = dp(
common_attention.multihead_attention,
x,
None,
pseudolocal_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout)
elif layer_type == "att_local_expert":
y, loss = dp(
common_attention.local_expert_attention,
x,
k=hparams.attention_moe_k,
loss_coef=hparams.attention_load_balance,
attention_num_experts=hparams.attention_num_experts,
train=hparams.mode == ModeKeys.TRAIN,
batch_coordinate=batch_coordinate,
mask_right=hparams.mask_right,
split_batch=bool(hparams.attention_split_batch),
attention_kq_size=hparams.attention_kq_size,
attention_v_size=hparams.attention_v_size)
# TODO(avaswani, epot, noam): Do we need to divide by num shards ?
extra_loss += tf.add_n(loss) / dp.n
elif layer_type == "att_lsh":
if hparams.lsh_truncated:
attention_fn = common_attention.multihead_attention_sparse_truncated
else:
attention_fn = common_attention.multihead_attention_sparse_dot_prod
y, loss = dp(
attention_fn,
x,
None,
None, # Bias is computed inside
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
# Additional parameters
bi=[
common_attention.BatchInfo(
coordinates=batch_coordinate[i],
order=None, # No future mask
) for i in range(dp.n)
],
use_map_fn=False,
experts_params=dict(nb_hyperplanes=4,))
extra_loss += tf.add_n(loss) / dp.n
elif layer_type == "ffn":
y = dp(
expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes,
hparams.hidden_size),
dp(expert_utils.flatten_all_but_last, x))
y = dp(common_layers.reshape_like, y, x)
elif layer_type == "conv":
y = dp(
common_layers.conv1d,
x,
hparams.hidden_size,
hparams.kernel_height,
activation=tf.nn.relu,
padding="SAME",
)
else:
raise ValueError("unknown sublayer %s" % layer_type)
if _should_postprocess(layer_type):
x = postprocess(x, y)
else:
x = y
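# Final layer_preprocess pass, then restore the dummy axis so the output is
# [batch, length, 1, hidden_size] as the rest of T2T expects.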
x = preprocess(x)
decoder_output = dp(tf.expand_dims, x, 2)
return decoder_output, extra_loss