in tensor2tensor/models/research/attention_lm_moe.py [0:0]
def body_sharded(self, sharded_features):
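    """Sharded body of the attention-LM mixture-of-experts model.

    Args:
      sharded_features: a dict of lists of Tensors (one entry per data shard);
        "inputs" / "targets" are expected to have shape
        [batch, length, 1, hidden_size].

    Returns:
      A (decoder_output, extra_loss) pair: decoder_output is the list of
      per-shard [batch, length, 1, hidden_size] outputs, and extra_loss is the
      scalar auxiliary (load-balancing) loss accumulated from the expert
      layers.
    """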
# Remove dropout if not training
hparams = self._hparams
dp = self._data_parallelism
if hparams.use_inputs:
decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
decoder_self_attention_bias = None
else:
targets = sharded_features["targets"]
targets = dp(tf.squeeze, targets, 2)
(decoder_input, decoder_self_attention_bias, pad_remover) = dp(
attention_lm_moe_prepare_decoder, targets, hparams)
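    # At this point decoder_input is a list (one Tensor per shard) of
    # [batch, length, hidden_size] activations; when set,
    # decoder_self_attention_bias masks the positions that must not be
    # attended to (e.g. the future).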
def preprocess(x):
return dp(common_layers.layer_preprocess, x, hparams)
def postprocess(x, y):
return dp(common_layers.layer_postprocess, x, y, hparams)
x = dp(tf.nn.dropout, decoder_input,
1.0 - hparams.layer_prepostprocess_dropout)
extra_loss = 0.0
if not hparams.use_inputs:
      # Because preprocess and postprocess are called with a batch of size one
      # (all batches concatenated), just make sure that batch norm is not used
      # (it should not be used here in any case).
assert hparams.norm_type != "batch"
tf.logging.info("Applying Padding Remover for the attention experts")
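      # remove_pad flattens each shard into a pseudo-batch of size one and
      # drops the padding positions so the expert layers only see real tokens;
      # restore_pad scatters the outputs back to their original positions.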
dp_remove_pad = functools.partial(
dp, remove_pad, pad_remover=pad_remover, mode=hparams.mode)
dp_restore_pad = functools.partial(
dp, restore_pad, ref_x=x, pad_remover=pad_remover, mode=hparams.mode)
else:
      # Identity functions: no padding is removed or restored.
dp_remove_pad = lambda x: x
dp_restore_pad = lambda x: x
if hparams.attention_exp_factor != 0:
tf.logging.info("Expand/compress tokens before sending them to experts")
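      # Roughly: deconv_elems_1d stretches each shard to attention_exp_factor
      # times more positions (with depth attention_exp_inputdim) before the
      # expert attention, and conv_elems_1d compresses the result back to the
      # original length and depth; the batch coordinates are expanded the same
      # way with expand_batch_coordinates.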
dp_expand_bc = lambda x: dp( # pylint: disable=g-long-lambda
expand_batch_coordinates,
x,
hparams.attention_exp_factor)
dp_expand_x = lambda x: dp( # pylint: disable=g-long-lambda
common_attention.deconv_elems_1d,
x,
hparams.attention_exp_factor,
hparams.attention_exp_inputdim)
dp_compress_x = lambda x, l: dp( # pylint: disable=g-long-lambda
common_attention.conv_elems_1d,
x,
hparams.attention_exp_factor,
l)
else:
dp_expand_bc = lambda x: x
dp_expand_x = lambda x: x
dp_compress_x = lambda x, l: x
def print_shape(x, suffix, debug=False):
      # To help debugging, print the input/output shapes at inference and eval.
      # Inference on long sequences can take a long time, so this helps to
      # follow the progress of the generation.
if not debug and hparams.mode == ModeKeys.TRAIN:
return x
return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))
with tf.name_scope("batch_coordinate_preprocess"):
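      # batch_coordinate holds, for every token, the index of the example it
      # belongs to, and batch_order its position within the sequence; both are
      # consumed by the expert/LSH attention layers to build masks over the
      # concatenated pseudo-batch.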
batch_coordinate = dp(get_batch_coordinate, x)
batch_coordinate = dp_remove_pad(batch_coordinate)
batch_coordinate = dp_expand_bc(batch_coordinate)
batch_order = dp(get_batch_coordinate, x, axis=-1)
batch_order = dp_remove_pad(batch_order)
batch_order = dp_expand_bc(batch_order)
x = dp(print_shape, x, "in")
assert hparams.batch_size >= hparams.max_length
num_hidden_layers = (
len(hparams.attention_layers) or hparams.num_hidden_layers)
for layer in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer):
# Use the layer type defined in attention_layers
if hparams.attention_layers:
attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]]
else:
attention_type = hparams.attention_type
with tf.variable_scope(
"attention_{}".format(attention_type)):
if attention_type in [
AttentionType.MULTIHEAD, AttentionType.MULTIHEAD_FULL]:
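            # Standard multihead self-attention: MULTIHEAD uses masked local
            # attention when attention_local is set, while MULTIHEAD_FULL
            # always uses full dot-product attention.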
attention_dot_type = (
"local_mask_right" if hparams.attention_local else
"dot_product")
if attention_type == AttentionType.MULTIHEAD_FULL:
attention_dot_type = "dot_product"
y = dp(
common_attention.multihead_attention,
preprocess(x),
None,
decoder_self_attention_bias,
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
attention_type=attention_dot_type,
block_length=hparams.attention_block_length,
name="decoder_self_attention")
elif attention_type == AttentionType.SPARSE_MULTIHEAD:
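            # LSH-based sparse attention: padding is removed first, and each
            # BatchInfo carries the example index and position of every token
            # so the attention bias can be computed inside the layer.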
x_in = preprocess(x)
x_in = dp_remove_pad(x_in)
y, loss_experts = dp(
common_attention.multihead_attention_sparse_dot_prod,
x_in,
None,
None, # Bias is computed inside
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
# Additional parameters
bi=[common_attention.BatchInfo(
coordinates=batch_coordinate[i],
order=batch_order[i], # No future mask
) for i in range(dp.n)],
use_map_fn=hparams.lsh_use_map_fn,
experts_params=dict(
nb_hyperplanes=hparams.lsh_num_hyperplanes,
),
)
y = dp_restore_pad(y)
# TODO(avaswani, epot, noam): Do we need to divide by num shards ?
extra_loss += tf.add_n(loss_experts) / dp.n
elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
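            # Truncated variant of the LSH-based sparse attention;
            # mask_right=True enforces the autoregressive constraint (no
            # attention to future positions).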
x_in = preprocess(x)
y, loss_experts = dp(
common_attention.multihead_attention_sparse_truncated,
x_in,
None,
None, # Bias is computed inside
hparams.attention_key_channels or hparams.hidden_size,
hparams.attention_value_channels or hparams.hidden_size,
hparams.hidden_size,
hparams.num_heads,
hparams.attention_dropout,
# Additional parameters
bi=[common_attention.BatchInfo(
coordinates=batch_coordinate[i],
                    order=batch_order[i],  # Position, used to mask the future
) for i in range(dp.n)],
mask_right=True,
experts_params=dict(
nb_hyperplanes=hparams.lsh_num_hyperplanes,
),
)
# TODO(avaswani, epot, noam): Do we need to divide by num shards ?
extra_loss += tf.add_n(loss_experts) / dp.n
elif attention_type == AttentionType.MEMORY_EFFICIENT:
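            # The memory-efficient attention is fed the raw x (it handles its
            # own normalization), which is why layer_preprocess_sequence must
            # be "n".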
assert hparams.layer_preprocess_sequence == "n"
y = dp(
common_attention.multihead_self_attention_memory_efficient,
x,
decoder_self_attention_bias,
hparams.num_heads,
name="decoder_self_attention")
elif attention_type == AttentionType.MULTIHEAD_REDUCED:
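            # Reduced self-attention: the memory (keys/values) is shortened by
            # attention_red_factor before the multihead attention is applied.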
y = dp(
common_attention.multihead_self_attention_reduced,
preprocess(x),
factor=hparams.attention_red_factor,
reduction_type=hparams.attention_reduction_type,
nonlinearity=hparams.attention_nonlinearity,
multihead_params=dict(
total_key_depth=
hparams.attention_key_channels or hparams.hidden_size,
total_value_depth=
hparams.attention_value_channels or hparams.hidden_size,
num_heads=hparams.num_heads,
dropout_rate=hparams.attention_dropout,
))
elif attention_type == AttentionType.LOCAL_EXPERTS:
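            # Mixture-of-experts attention: each token is dispatched to k
            # expert attention heads. batch_coordinate keeps tokens from
            # attending across examples, mask_right hides future positions in
            # the pure LM setting, and loss is the experts' load-balancing
            # loss.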
x_in = preprocess(x)
x_in = dp_remove_pad(x_in)
x_in = dp_expand_x(x_in)
y, loss = dp(
common_attention.local_expert_attention,
x_in,
k=hparams.attention_moe_k,
loss_coef=hparams.attention_load_balance,
attention_num_experts=hparams.attention_num_experts,
train=hparams.mode == ModeKeys.TRAIN,
batch_coordinate=batch_coordinate,
mask_right=not hparams.use_inputs,
split_batch=bool(hparams.attention_split_batch),
attention_num_head=hparams.attention_num_head,
attention_kq_size=hparams.attention_kq_size,
attention_v_size=hparams.attention_v_size)
y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
y = dp_restore_pad(y)
# TODO(avaswani, epot, noam): Do we need to divide by num shards ?
extra_loss += tf.add_n(loss) / dp.n
else:
raise ValueError("Only {} supported for now.".format(
AttentionType.get_choices()))
x = postprocess(x, y)
with tf.variable_scope("ffn"):
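          # Position-wise feed-forward layer: either the fused
          # memory-efficient version, or a standard conv_hidden_relu
          # (optionally with the left-padded convolution parameters copied
          # from the transformer model).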
if hparams.memory_efficient_ffn:
assert hparams.layer_preprocess_sequence == "n"
y = dp(
common_layers.conv_hidden_relu_memory_efficient,
x,
hparams.filter_size)
else:
additional_conv_params = {}
if hparams.use_sepconv:
additional_conv_params = dict(
padding="LEFT",
# Parameters copied from the transformer model
kernel_size=(3, 1),
second_kernel_size=(31, 1),
)
y = dp(
common_layers.conv_hidden_relu,
preprocess(x),
hparams.filter_size,
hparams.hidden_size,
dropout=hparams.relu_dropout,
**additional_conv_params
)
x = postprocess(x, y)
x = preprocess(x)
decoder_output = dp(tf.expand_dims, x, 2)
return decoder_output, extra_loss
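

# A minimal, hypothetical sketch of the padding-removal trick used above (the
# real implementation is expert_utils.PadRemover together with remove_pad /
# restore_pad; the toy_* helpers below are illustrative only and not part of
# the tensor2tensor API). The idea: flatten the batch, keep only the
# non-padding rows so the expert layers see a single pseudo-batch of real
# tokens, then scatter the outputs back into their original positions.
def toy_remove_pad(x, nonpad_ids):
  """[batch, length, hidden] -> [1, num_nonpad, hidden]."""
  flat = tf.reshape(x, [-1, tf.shape(x)[-1]])  # [batch * length, hidden]
  kept = tf.gather(flat, nonpad_ids)           # keep only the non-padding rows
  return tf.expand_dims(kept, axis=0)          # pseudo-batch of size one


def toy_restore_pad(y, nonpad_ids, batch, length, hidden):
  """Inverse of toy_remove_pad; batch/length/hidden are python ints."""
  kept = tf.squeeze(y, axis=0)                 # [num_nonpad, hidden]
  flat = tf.scatter_nd(
      indices=tf.expand_dims(nonpad_ids, axis=-1),
      updates=kept,
      shape=[batch * length, hidden])          # zeros at padded positions
  return tf.reshape(flat, [batch, length, hidden])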