def body()

in tensorflow_addons/seq2seq/decoder.py [0:0]


        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Run a single decoder step; serves as the while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)`.
            """
            step_result = decoder.step(time, inputs, state, training)
            (next_outputs, decoder_state, next_inputs, decoder_finished) = step_result

            # `state_lengths` stays None unless the decoder tracks its own
            # finished flags AND exposes `lengths` on the state it returned.
            state_lengths = None
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
                state_lengths = getattr(decoder_state, "lengths", None)
            else:
                next_finished = tf.logical_or(decoder_finished, finished)

            if state_lengths is None:
                # Entries that were not finished on entry to this step get
                # their length bumped to `time + 1`; the rest are kept as is.
                next_sequence_lengths = tf.where(
                    tf.logical_not(finished),
                    tf.fill(tf.shape(sequence_lengths), time + 1),
                    sequence_lengths,
                )
            else:
                # The decoder state is the source of truth for the lengths.
                # The value is only threaded through the loop; it is read
                # from the decoder state again on the next iteration.
                next_sequence_lengths = tf.cast(state_lengths, tf.int32)

            # The loop variables must keep the same nested structure from one
            # iteration to the next.
            tf.nest.assert_same_structure(state, decoder_state)
            tf.nest.assert_same_structure(outputs_ta, next_outputs)
            tf.nest.assert_same_structure(inputs, next_inputs)

            def zero_out_finished(out, zero):
                # Replace the output of already-finished entries with `zero`.
                # When `zero` has a higher rank than `finished`, the mask gets
                # a trailing axis and is broadcast to the shape of `zero`.
                mask = finished
                if finished.shape.rank < zero.shape.rank:
                    mask = tf.broadcast_to(
                        tf.expand_dims(finished, axis=-1), zero.shape
                    )
                return tf.where(mask, zero, out)

            def _maybe_copy_state(new, cur):
                # TensorArrays are handed through untouched.
                if isinstance(cur, tf.TensorArray):
                    return new
                new.set_shape(cur.shape)
                # Scalar states are handed through untouched as well.
                if new.shape.ndims == 0:
                    return new
                # Finished entries keep their current state value.
                mask = tf.broadcast_to(tf.expand_dims(finished, axis=-1), new.shape)
                return tf.where(mask, cur, new)

            if impute_finished:
                # Zero out output values and copy through states past finish.
                emit = tf.nest.map_structure(
                    zero_out_finished, next_outputs, zero_outputs
                )
                next_state = tf.nest.map_structure(
                    _maybe_copy_state, decoder_state, state
                )
            else:
                emit = next_outputs
                next_state = decoder_state

            if enable_tflite_convertible:

                def _flatten(value):
                    # Reshape to 1-D.
                    return tf.reshape(value, [-1])

                emit = tf.nest.map_structure(_flatten, emit)

            def _write_step(ta, value):
                # Store this step's emitted value at index `time`.
                return ta.write(time, value)

            outputs_ta = tf.nest.map_structure(_write_step, outputs_ta, emit)

            return (
                time + 1,
                outputs_ta,
                next_state,
                next_inputs,
                next_finished,
                next_sequence_lengths,
            )