in tensorflow_addons/seq2seq/decoder.py [0:0]
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
    """Internal while_loop body.

    Runs one `decoder.step`, updates the finished flags and sequence
    lengths, optionally zeroes outputs and copies states through for batch
    entries that were already finished (`impute_finished`), and writes the
    emitted outputs into `outputs_ta` at index `time`.

    Args:
      time: scalar int32 tensor.
      outputs_ta: structure of TensorArray.
      state: (structure of) state tensors and TensorArrays.
      inputs: (structure of) input tensors.
      finished: bool tensor (keeping track of what's finished).
      sequence_lengths: int32 tensor (keeping track of time of finish).

    Returns:
      `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
        next_sequence_lengths)`.
    """
    (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
        time, inputs, state, training
    )
    decoder_state_sequence_lengths = False
    if decoder.tracks_own_finished:
        next_finished = decoder_finished
        lengths = getattr(decoder_state, "lengths", None)
        if lengths is not None:
            # sequence lengths are provided by decoder_state.lengths;
            # overwrite our sequence lengths.
            decoder_state_sequence_lengths = True
            sequence_lengths = tf.cast(lengths, tf.int32)
    else:
        next_finished = tf.logical_or(decoder_finished, finished)

    if decoder_state_sequence_lengths:
        # Just pass something through the loop; at the next iteration
        # we'll pull the sequence lengths from the decoder_state again.
        next_sequence_lengths = sequence_lengths
    else:
        # Entries not yet finished before this step get length `time + 1`;
        # entries that finished earlier keep their recorded length.
        next_sequence_lengths = tf.where(
            tf.logical_not(finished),
            tf.fill(tf.shape(sequence_lengths), time + 1),
            sequence_lengths,
        )

    tf.nest.assert_same_structure(state, decoder_state)
    tf.nest.assert_same_structure(outputs_ta, next_outputs)
    tf.nest.assert_same_structure(inputs, next_inputs)

    def _finished_mask_for(target):
        """Return a `finished` mask usable with `tf.where` against `target`.

        When `target` has a higher rank than `finished`, one trailing unit
        axis is appended per extra dimension (not just one), so the mask is
        aligned with `target`'s leading axes for any rank; TF2 `tf.where`
        broadcasts numpy-style, so the ranks must line up explicitly. The
        static shape is used when fully known, otherwise the dynamic
        `tf.shape` so partially-known shapes do not fail to convert.
        """
        extra_dims = target.shape.rank - finished.shape.rank
        if extra_dims <= 0:
            return finished
        mask = finished
        for _ in range(extra_dims):
            mask = tf.expand_dims(mask, axis=-1)
        if target.shape.is_fully_defined():
            target_shape = target.shape
        else:
            target_shape = tf.shape(target)
        return tf.broadcast_to(mask, target_shape)

    # Zero out output values past finish
    if impute_finished:

        def zero_out_finished(out, zero):
            return tf.where(_finished_mask_for(zero), zero, out)

        emit = tf.nest.map_structure(zero_out_finished, next_outputs, zero_outputs)
    else:
        emit = next_outputs

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tf.TensorArray):
            return new
        new.set_shape(cur.shape)
        if new.shape.ndims == 0:
            return new
        return tf.where(_finished_mask_for(new), cur, new)

    if impute_finished:
        next_state = tf.nest.map_structure(_maybe_copy_state, decoder_state, state)
    else:
        next_state = decoder_state

    if enable_tflite_convertible:
        # Reshape to 1-D.
        emit = tf.nest.map_structure(lambda x: tf.reshape(x, [-1]), emit)

    outputs_ta = tf.nest.map_structure(
        lambda ta, out: ta.write(time, out), outputs_ta, emit
    )
    return (
        time + 1,
        outputs_ta,
        next_state,
        next_inputs,
        next_finished,
        next_sequence_lengths,
    )