in tensor2tensor/models/transformer.py [0:0]
def _fast_decode(self,
features,
decode_length,
beam_size=1,
top_beams=1,
alpha=1.0,
preprocess_targets_method=None):
"""Fast decoding.
Implements both greedy and beam search decoding; beam search is used iff
beam_size > 1, otherwise beam-search-related arguments are ignored.
Args:
features: a map of string to model features.
decode_length: an integer. How many additional timesteps to decode.
beam_size: number of beams.
top_beams: an integer. How many of the beams to return.
alpha: Float that controls the length penalty. The larger the alpha, the
stronger the preference for longer translations.
preprocess_targets_method: method used to preprocess targets. If None,
uses method "preprocess_targets" defined inside this method.
Returns:
A dict of decoding results {
"outputs": integer `Tensor` of decoded ids of shape
[batch_size, <= decode_length] if beam_size == 1 or
[batch_size, top_beams, <= decode_length]
"scores": decoding log probs from the beam search,
None if using greedy decoding (beam_size=1)
}
Raises:
NotImplementedError: If there are multiple data shards.
"""
if self._num_datashards != 1:
raise NotImplementedError("Fast decoding only supports a single shard.")
dp = self._data_parallelism
hparams = self._hparams
target_modality = self._problem_hparams.modality["targets"]
target_vocab_size = self._problem_hparams.vocab_size["targets"]
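# Round the target vocab size up to the next multiple of vocab_divisor,
# presumably so it matches the padded vocabulary used when the embedding and
# softmax variables were built.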
if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
if "targets_segmentation" in features:
raise NotImplementedError(
"Decoding not supported on packed datasets "
" If you want to decode from a dataset, use the non-packed version"
" of the dataset when decoding.")
if self.has_input:
inputs_shape = common_layers.shape_list(features["inputs"])
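# Class-label and regression targets need only a single decoding step; for
# sequence targets, decode up to the input length plus `decode_length`
# extra positions.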
if (target_modality == modalities.ModalityType.CLASS_LABEL or
self._problem_hparams.get("regression_targets")):
decode_length = 1
else:
decode_length = (
inputs_shape[1] + features.get("decode_length", decode_length))
batch_size = inputs_shape[0]
inputs = self._prepare_inputs_for_decode(features)
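# Run the encoder once up front; its output and attention bias are reused at
# every decoding step. dp() returns one value per data shard, so index [0]
# unpacks the single shard.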
with tf.variable_scope("body"):
encoder_output, encoder_decoder_attention_bias = dp(
self.encode,
inputs,
features["target_space_id"],
hparams,
features=features)
encoder_output = encoder_output[0]
encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
partial_targets = features.get("partial_targets")
else:
# The problem has no inputs.
encoder_output = None
encoder_decoder_attention_bias = None
# Prepare partial targets.
# In either features["inputs"] or features["targets"].
# We force the outputs to begin with these sequences.
partial_targets = features.get("inputs")
if partial_targets is None:
partial_targets = features["targets"]
assert partial_targets is not None
if partial_targets is not None:
partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
partial_targets = tf.to_int64(partial_targets)
partial_targets_shape = common_layers.shape_list(partial_targets)
partial_targets_length = partial_targets_shape[1]
decode_length = (
partial_targets_length + features.get("decode_length", decode_length))
batch_size = partial_targets_shape[0]
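# Precompute positional information for the full decode length so that
# preprocess_targets below can slice out the signal for position i at each
# step.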
if hparams.pos == "timing":
positional_encoding = common_attention.get_timing_signal_1d(
decode_length + 1, hparams.hidden_size)
elif hparams.pos == "timing_from_features":
positional_encoding = common_attention.add_timing_signals_from_features(
tf.zeros([1, decode_length, hparams.hidden_size]), features,
hparams.position_features)
elif hparams.pos == "emb":
positional_encoding = common_attention.add_positional_embedding(
tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length,
"body/targets_positional_embedding", None)
else:
positional_encoding = None
def preprocess_targets(targets, i):
"""Performs preprocessing steps on the targets to prepare for the decoder.
This includes:
- Embedding the ids.
- Flattening to 3D tensor.
- Optionally adding timing signals.
Args:
targets: input ids to the decoder. [batch_size, 1]
i: scalar, step number of the decoding loop.
Returns:
Processed targets [batch_size, 1, hidden_dim]
"""
# _shard_features is called to ensure that the variable names match.
targets = self._shard_features({"targets": targets})["targets"]
modality_name = hparams.name.get(
"targets",
modalities.get_name(target_modality))(hparams, target_vocab_size)
with tf.variable_scope(modality_name):
bottom = hparams.bottom.get(
"targets", modalities.get_targets_bottom(target_modality))
targets = dp(bottom, targets, hparams, target_vocab_size)[0]
targets = common_layers.flatten4d3d(targets)
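# After the bottom transform, targets are [batch_size, 1, 1, hidden_size];
# flatten4d3d merges the middle dimensions, giving [batch_size, 1, hidden_size].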
# GO embeddings are all zero because transformer_prepare_decoder shifts the
# targets along by one for the input, which pads with zeros. If the modality
# already maps GO to the zero embedding, this is not needed.
if not self.get_decode_start_id():
targets = tf.cond(
tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)
if positional_encoding is not None:
targets += positional_encoding[:, i:i + 1]
return targets
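# Precompute the full lower-triangular (causal) self-attention bias; each
# step of symbols_to_logits_fn slices out row i so position i can only attend
# to positions <= i.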
decoder_self_attention_bias = (
common_attention.attention_bias_lower_triangle(decode_length))
if hparams.proximity_bias:
decoder_self_attention_bias += common_attention.attention_bias_proximal(
decode_length)
# Create tensors for encoder-decoder attention history
att_cache = {"attention_history": {}}
num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
if encoder_output is not None:
att_batch_size, enc_seq_length = common_layers.shape_list(
encoder_output)[0:2]
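# Each decoder layer starts with an empty [batch, num_heads, 0, enc_length]
# attention-history tensor; one slice of attention weights is concatenated
# per decoded position.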
for layer in range(num_layers):
att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
[att_batch_size, hparams.num_heads, 0, enc_seq_length])
def update_decoder_attention_history(cache):
"""Save attention weights in cache, e.g., for vizualization."""
for k in [x for x in self.attention_weights
if "decoder" in x and "self" not in x and "logits" not in x]:
idx = k.find("layer_")
if idx < 0:
continue
# Get layer number from the string name.
layer_nbr = k[idx + 6:]
idx = 0
while idx + 1 < len(layer_nbr) and layer_nbr[:idx + 1].isdigit():
idx += 1
layer_nbr = "layer_%d" % int(layer_nbr[:idx])
if layer_nbr in cache["attention_history"]:
cache["attention_history"][layer_nbr] = tf.concat(
[cache["attention_history"][layer_nbr],
self.attention_weights[k]],
axis=2)
if not preprocess_targets_method:
preprocess_targets_method = preprocess_targets
def symbols_to_logits_fn(ids, i, cache):
"""Go from ids to logits for next symbol."""
ids = ids[:, -1:]
targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
targets = preprocess_targets_method(targets, i)
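# Slice row i of the precomputed causal mask: the current position may
# attend to positions 0..i.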
bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
with tf.variable_scope("body"):
body_outputs = dp(
self.decode,
targets,
cache.get("encoder_output"),
cache.get("encoder_decoder_attention_bias"),
bias,
hparams,
cache,
nonpadding=features_to_nonpadding(features, "targets"))
update_decoder_attention_history(cache)
modality_name = hparams.name.get(
"targets",
modalities.get_name(target_modality))(hparams, target_vocab_size)
with tf.variable_scope(modality_name):
top = hparams.top.get("targets", modalities.get_top(target_modality))
logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]
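# Squeeze the singleton time/space dimensions of the logits, leaving
# per-position logits of shape [batch_size (possibly * beam_size), vocab_size].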
ret = tf.squeeze(logits, axis=[1, 2, 3])
if partial_targets is not None:
# If the position is within the given partial targets, we alter the
# logits to always return those values.
# A faster approach would be to process the partial targets in one
# iteration in order to fill the corresponding parts of the cache.
# This would require broader changes, though.
vocab_size = tf.shape(ret)[1]
def forced_logits():
return tf.one_hot(
tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
-1e9)
ret = tf.cond(
tf.less(i, partial_targets_length), forced_logits, lambda: ret)
return ret, cache
sos_id = self.get_decode_start_id() or 0
eos_id = self.get_decode_end_id() or beam_search.EOS_ID
temperature = features.get("sampling_temp",
getattr(hparams, "sampling_temp", 0.0))
top_k = features.get("sampling_keep_top_k",
getattr(hparams, "sampling_keep_top_k", -1))
ret = fast_decode(
encoder_output=encoder_output,
encoder_decoder_attention_bias=encoder_decoder_attention_bias,
symbols_to_logits_fn=symbols_to_logits_fn,
hparams=hparams,
decode_length=decode_length,
vocab_size=target_vocab_size,
init_cache_fn=self._init_cache_fn,
beam_size=beam_size,
top_beams=top_beams,
alpha=alpha,
batch_size=batch_size,
force_decode_length=self._decode_hparams.force_decode_length,
sos_id=sos_id,
eos_id=eos_id,
sampling_temperature=temperature,
top_k=top_k,
cache=att_cache)
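# Strip the forced partial-target prefix from the returned ids; the beam
# dimension is present only when multiple beams are returned.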
if partial_targets is not None:
if beam_size <= 1 or top_beams <= 1:
ret["outputs"] = ret["outputs"][:, partial_targets_length:]
else:
ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
return ret