in tensor2tensor/models/research/transformer_symshard.py [0:0]
def body(self, features):
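"""Transformer body with symmetric model parallelism across shards.

Each shard holds its own embedding and layer weights; shard outputs are
combined to produce the logits and the training loss.
"""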
hparams = self._hparams
ps_devices = self._ps_devices
single_device = (len(ps_devices) == 1)
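# Distribute the model shards evenly across the parameter-server devices.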
assert hparams.num_model_shards % len(ps_devices) == 0
shards_per_device = hparams.num_model_shards // len(ps_devices)
model_devices = [ps_devices[i // shards_per_device]
for i in range(hparams.num_model_shards)]
print("model_devices = %s" % model_devices)
mp = expert_utils.Parallelism(model_devices, reuse=False)
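# mp(fn, ...) runs fn once per shard, placing each call on that shard's
# device and distributing list-valued arguments across the shards.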
targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
# Squeeze out the size-1 height and channel dimensions:
# [batch, length, 1, 1] -> [batch, length].
targets = tf.squeeze(features["targets_raw"], [2, 3])
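# Create one embedding variable per shard; together they span the full
# model dimension (hidden_size per shard).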
targets_embedding_var = mp(
tf.get_variable, "embedding",
[[targets_vocab_size, hparams.hidden_size]] * mp.n,
initializer=tf.random_normal_initializer(
0.0, hparams.hidden_size**-0.5))
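# Shift the targets right so that each position is predicted only from
# earlier positions (teacher forcing).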
shifted_targets = common_layers.shift_right_2d(targets)
# Bypass the symbol modality and use a different embedding on each shard.
if single_device:
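# On a single device, embed once with the concatenated table, then split
# the result back into per-shard pieces.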
targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
decoder_input_combined = common_layers.embedding(
shifted_targets, targets_vocab_size,
hparams.hidden_size * mp.n,
multiplier=hparams.hidden_size**0.5,
embedding_var=targets_embedding_var_combined,
)
decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
else:
targets_embedding_var_combined = None
decoder_input = mp(
common_layers.embedding, shifted_targets, targets_vocab_size,
hparams.hidden_size,
multiplier=hparams.hidden_size**0.5,
embedding_var=targets_embedding_var,
)
decoder_self_attention_bias = mp(
common_attention.attention_bias_lower_triangle,
tf.shape(targets)[1])
if "targets_segmentation" in features:
# "Packed" dataset - keep the examples from seeing each other.
targets_segmentation = features["targets_segmentation"]
targets_position = features["targets_position"]
decoder_self_attention_bias = mp(
tf.add, decoder_self_attention_bias,
mp(common_attention.attention_bias_same_segment,
targets_segmentation, targets_segmentation))
decoder_input = mp(
common_attention.add_timing_signal_1d_given_position,
decoder_input, targets_position)
else:
targets_position = None
# The lower-triangle bias computed above already applies; just add the
# standard timing signal below.
decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
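# Encoder: embed the inputs (sharing the target embedding) and run the
# encoder stack.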
if self.has_input:
inputs = tf.squeeze(features["inputs_raw"], [2, 3])
inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
# Share the input and target embeddings (the vocabularies must match).
share_inputs_and_targets_embedding = True
if share_inputs_and_targets_embedding:
assert inputs_vocab_size == targets_vocab_size
inputs_embedding_var = targets_embedding_var
inputs_embedding_var_combined = targets_embedding_var_combined
if single_device:
encoder_input_combined = common_layers.embedding(
inputs, inputs_vocab_size,
hparams.hidden_size * mp.n,
multiplier=hparams.hidden_size**0.5,
embedding_var=inputs_embedding_var_combined,
)
encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
else:
encoder_input = mp(
common_layers.embedding, inputs, inputs_vocab_size,
hparams.hidden_size,
multiplier=hparams.hidden_size**0.5,
embedding_var=inputs_embedding_var,
)
if "inputs_segmentation" in features:
# "Packed" dataset - keep the examples from seeing each other.
inputs_segmentation = features["inputs_segmentation"]
inputs_position = features["inputs_position"]
encoder_self_attention_bias = mp(
common_attention.attention_bias_same_segment,
inputs_segmentation, inputs_segmentation)
encoder_decoder_attention_bias = mp(
common_attention.attention_bias_same_segment,
targets_segmentation, inputs_segmentation)
encoder_input = mp(
common_attention.add_timing_signal_1d_given_position,
encoder_input, inputs_position)
else:
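# Not packed: only mask out attention to padding positions.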
encoder_padding = tf.to_float(tf.equal(inputs, 0))
ignore_padding = common_attention.attention_bias_ignore_padding(
encoder_padding)
encoder_self_attention_bias = ignore_padding
encoder_decoder_attention_bias = ignore_padding
inputs_position = None
encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)
# Encoder stack: dropout on the embedded inputs, then the encoder layers.
with tf.variable_scope("encoder"):
encoder_input = mp(
tf.nn.dropout, encoder_input,
1.0 - hparams.layer_prepostprocess_dropout)
encoder_output = _layer_stack(
mp,
encoder_input,
encoder_self_attention_bias,
hparams.encoder_layers,
hparams)
else:
encoder_decoder_attention_bias = None
encoder_output = None
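# Decoder stack; attends to the encoder output when inputs are present.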
with tf.variable_scope("decoder"):
decoder_input = mp(
tf.nn.dropout, decoder_input,
1.0 - hparams.layer_prepostprocess_dropout)
decoder_output = _layer_stack(
mp,
decoder_input,
decoder_self_attention_bias,
layers=hparams.decoder_layers,
hparams=hparams,
encoder_output=encoder_output,
encoder_decoder_attention_bias=encoder_decoder_attention_bias)
# Bypass the symbol modality and compute logits directly.
# We compute a different set of logits on each shard, and sum them.
# Share the weights with the target embedding.
output_var = targets_embedding_var
output_var_combined = targets_embedding_var_combined
if single_device:
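# Concatenate the shard outputs and compute logits against the combined
# (tied) target embedding.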
decoder_output = tf.concat(decoder_output, 2)
logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
num, denom = common_layers.padded_cross_entropy(
logits, targets, hparams.label_smoothing)
training_loss = num / denom
else:
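# Each shard computes partial logits from its slice of the output and its
# own embedding; a ring all-reduce sums them into the full logits.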
logits = mp(
tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
logits = expert_utils.all_reduce_ring(logits, mp)
# On each device, we compute the loss for a part of the batch.
# This is faster than computing the whole loss on one shard.
mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
def _loss_for_shard(logits, targets, shard):
logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
return common_layers.padded_cross_entropy(
logits, targets, hparams.label_smoothing)
# Parallelism only distributes list arguments, so materialize the range.
num, denom = mp(_loss_for_shard, logits, targets, list(range(mp.n)))
training_loss = tf.add_n(num) / tf.add_n(denom)
logits = logits[0]
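# Expand to the [batch, length, 1, 1, vocab] shape that T2T expects for logits.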
logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
# Returning a "training" loss prevents the default modality loss from being
# computed externally.
losses = {"training": training_loss}
return logits, losses