in tensor2tensor/models/mtf_transformer.py [0:0]
def _sample(self, features, mesh):
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
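  # Encoder-decoder models: embed and encode the (padded) inputs, then
  # precompute the encoder-decoder attention tensors used at every decode step.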
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
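    # For each "enc_att" decoder layer, create its attention variables and
    # compute keys and values from the encoder output once, rather than
    # recomputing them at every decoding step.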
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.master_dtype, self.slice_dtype,
              self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  self.batch_dims + [self.heads_dim,
                                     self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  self.batch_dims + [self.heads_dim,
                                     self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
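  # Decoder-only models: there is no encoder, but decoding may be forced to
  # begin with a given prefix ("partial targets").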
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                            [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    raise ValueError(
        "hparams.transformer_type = %s not yet supported"
        % hparams.transformer_type)
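  # Shapes for the decoding-loop state: one cached key tensor and one cached
  # value tensor per self-attention layer (full length for "att", a fixed
  # window for "local_att"), with an extra beam dimension when beam_size > 1.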
  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [self.heads_dim,
                                local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [beam_dim, self.heads_dim,
                                local_attention_window, self.kv_dim])
  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
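  # Decoding starts from all-zero ids and empty (all-zero) attention caches.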
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)
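  # One step of incremental decoding: embed the most recently generated token,
  # run the decoder stack against the cached states, and project to logits.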
  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states
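  # beam_size == 1: greedy decoding (optionally with a sampling temperature).
  # Otherwise: beam search, returning only the highest-scoring beam.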
  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
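    # For encoder-decoder models, bound the decode length by an affine
    # function of the longest (non-padding) input sequence.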
    if hparams.transformer_type == "encdec":
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier
          + hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
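    # Keep only the top beam (beam index 0).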
    return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)