in mesh_tensorflow/transformer/transformer.py [0:0]
def sample_autoregressive(self,
                          partial_sequences,
                          stop_at_token=1,
                          max_steps=None,
                          temperature=0.0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          bos_id=0):
  """Sample randomly one token at a time.

  The partial_sequences represent partial sequences to be continued. The
  first tokens of each sequence are nonzero representing the given partial
  sequences and the last tokens of each sequence are zeros, representing what
  needs to be filled in.

  If there are no partial sequences (you want to sample from the beginning),
  then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
  has_partial_sequences=False (so we can skip computation).

  Args:
    partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
    stop_at_token: an optional integer eos id. Stop when we produce it.
    max_steps: an optional integer, the max number of steps to decode.
      A falsy value (None or 0) means no step limit.
    temperature: an optional floating point value between 0.0 and 1.0.
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    variable_dtype: a mtf.VariableDType
    encoder_output: an optional Tensor
    encoder_sequence_id: an optional Tensor
    encoder_inputs: an optional Tensor
    shared_params: an optional dictionary
    has_partial_sequences: a boolean
    encoder_layer_outputs: optional - readonly list of tensor activations when
      decoding, one per each input layer + the embedding layer
    never_end: a boolean - if set, then avoid generating stop_at_token.
      Requires stop_at_token to be set.
    remove_partial_sequences: a boolean - whether to remove the partial
      sequences from the output
    sampling_keep_top_k: an integer - if not -1, only sample from the top k
      logits. Must be either -1 or positive.
    bos_id: beginning of sequence id. It is added to the input token fed to
      the model at position 0; 0 leaves that input unchanged.

  Returns:
    a Tensor with shape [<batch_dims>, length_dim]

  Raises:
    ValueError: if the model is not autoregressive, if sampling_keep_top_k is
      neither -1 nor positive, or if never_end is set while stop_at_token is
      None.
  """
  if not self.autoregressive:
    raise ValueError("must be autoregressive")
  # Validate static (Python-level) arguments before any graph ops are built.
  if sampling_keep_top_k != -1 and sampling_keep_top_k <= 0:
    raise ValueError("sampling_keep_top_k must either be -1 or positive.")
  if never_end and stop_at_token is None:
    raise ValueError("never_end=True requires stop_at_token to be set.")

  inputs = partial_sequences
  batch_dims = inputs.shape.dims[:-1]
  length_dim = inputs.shape.dims[-1]
  # Number of nonzero (given) tokens per sequence; decoding starts here.
  initial_position = mtf.reduce_sum(
      mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
  sequence_id = 1 if encoder_sequence_id is not None else None
  length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
  if self.input_full_attention:
    # Positions <= initial_position get priority 0 and later positions keep
    # their index; presumably this lets the given prefix attend to itself
    # bidirectionally -- TODO confirm against Context's priority semantics.
    read_priority = write_priority = length_range * mtf.to_int32(
        mtf.greater(length_range, initial_position))
  else:
    read_priority = write_priority = length_range

  # First pass: run the model over the whole (shifted) input in "first_part"
  # mode to collect the constant and initial decoder states for the loop.
  context_first_part = Context(
      model=self,
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      variable_dtype=variable_dtype,
      mode="first_part",
      position=length_range,
      position_is_default=True,
      new_states=[],
      initial_position=initial_position,
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=[],
      shared_params=shared_params,
      encoder_layer_outputs=encoder_layer_outputs,
      write_priority=write_priority,
      read_priority=read_priority,
      inputs=inputs,
      encoder_inputs=encoder_inputs)
  shifted_inputs = autoregressive_inputs(inputs)
  with tf.variable_scope(self.name):
    logits = self._call_internal(context_first_part, shifted_inputs)
  del logits
  constant_states = context_first_part.constant_states
  if not has_partial_sequences:
    # Nothing given: start from all-zero states.
    initial_states = [
        mtf.zeros_like(t) for t in context_first_part.new_states]
    partial_sequences_eos_count = 0
  else:
    initial_states = context_first_part.new_states
    # eos tokens already present in the prompt must not count as "done".
    partial_sequences_eos_count = mtf.reduce_sum(
        mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
        reduced_dim=length_dim)

  def cond_fn(position, ids, *unused_states):
    """Should we run another loop iteration."""
    past_end = mtf.greater_equal(position, length_dim.size)
    if max_steps:
      past_end = mtf.logical_or(
          past_end, mtf.greater_equal(position - initial_position, max_steps))
    is_done = past_end
    if stop_at_token is not None:
      eos_count = mtf.reduce_sum(
          mtf.to_int32(mtf.equal(ids, stop_at_token)),
          reduced_dim=length_dim)
      # A sequence is done once it has produced an eos beyond the prompt's.
      has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
      is_done = mtf.logical_or(is_done, has_additional_eos)
    # Keep looping while at least one sequence is still unfinished.
    all_done = mtf.reduce_all(is_done)
    return mtf.logical_not(all_done)

  def body_fn(position, ids, *states):
    """One step in the decode loop."""
    # Feed the token written at the previous position.
    inputs_this_step = mtf.gather(ids, position - 1, length_dim)
    # Setting proper bos_id for position == 0. No-op otherwise.
    if bos_id:
      inputs_this_step += bos_id * mtf.ones_like(inputs_this_step) * mtf.cast(
          mtf.equal(position, 0), tf.int32)
    context_incremental = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        position=position,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=position,
        inputs=inputs_this_step,
        encoder_inputs=encoder_inputs)
    with tf.variable_scope(self.name, reuse=True):
      logits = self._call_internal(context_incremental, inputs_this_step)
      if never_end:
        # Push the stop token's logit down by 1e9 so it is never chosen.
        logits += mtf.one_hot(
            mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
            self.output_vocab_dim, on_value=-1e9, off_value=0.0,
            dtype=logits.dtype)

    # TBD whether this should be before or after never_end:
    # Note for adding top_p sampling in the future, in other code bases, the
    # option to apply temperature is done before the top-k truncation. This
    # implementation does this in the opposite order. For top-k this doesn't
    # matter, but for top_p it will.
    if sampling_keep_top_k != -1:
      # Replace every logit <= the threshold with -1e6. Assumes
      # nth_largest_element's n is zero-indexed, so the threshold is the
      # (k+1)-th largest and exactly k logits survive -- confirm in mtf docs.
      k_largest = mtf.nth_largest_element(
          logits, n=sampling_keep_top_k,
          reduced_dim=self.output_vocab_dim)
      logits = mtf.where(mtf.less_equal(logits, k_largest),
                         mtf.ones_like(logits)*-1e6, logits)

    ids_this_step = mtf.sample_with_temperature(
        logits, self.output_vocab_dim, temperature)
    new_position = position + 1
    # Write the sampled token at `position`; per the docstring that slot is
    # still zero, so adding a one-hot acts as an assignment.
    new_ids = ids + ids_this_step * mtf.one_hot(
        position, length_dim, dtype=tf.int32)
    return [new_position, new_ids] + context_incremental.new_states

  while_loop_inputs = [initial_position, inputs] + initial_states
  final_position, outputs = mtf.while_loop(
      cond_fn, body_fn, while_loop_inputs)[:2]
  del final_position
  if has_partial_sequences and remove_partial_sequences:
    # Remove partial sequences from outputs: shift left by each sequence's
    # prompt length, which is exactly initial_position (inputs is
    # partial_sequences), so the reduction is not recomputed.
    outputs = mtf.dynamic_shift(
        outputs, -initial_position, length_dim, wrap=False)
  return outputs