def initialize()

in tensorflow_addons/seq2seq/sampler.py [0:0]


    def initialize(self, inputs, sequence_length=None, mask=None):
        """Initialize the TrainSampler.

        Args:
          inputs: A (structure of) input tensors.
          sequence_length: An optional int32 vector tensor of per-example
            lengths. Mutually exclusive with `mask`.
          mask: An optional boolean 2D tensor marking the valid time steps of
            each example. Mutually exclusive with `sequence_length`.

        Returns:
          (finished, next_inputs), a tuple of two items. The first item is a
            boolean vector to indicate whether the item in the batch has
            finished. The second item is the first slice of the input data,
            taken along the timestep dimension (usually the second dimension
            of the input).
        """
        self.inputs = tf.convert_to_tensor(inputs, name="inputs")
        if not self.time_major:
            inputs = tf.nest.map_structure(_transpose_batch_time, inputs)

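        # At this point the inputs are time-major either way, so dimension 1 is
        # the batch dimension.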
        self._batch_size = tf.shape(tf.nest.flatten(inputs)[0])[1]

        self.input_tas = tf.nest.map_structure(_unstack_ta, inputs)
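        # sequence_length and mask are two alternative ways to mark the valid
        # time steps of each example, so providing both is rejected as ambiguous.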
        if sequence_length is not None and mask is not None:
            raise ValueError(
                "sequence_length and mask can't be provided at the same time."
            )
        if sequence_length is not None:
            self.sequence_length = tf.convert_to_tensor(
                sequence_length, name="sequence_length"
            )
            if self.sequence_length.shape.ndims != 1:
                raise ValueError(
                    "Expected sequence_length to be vector, but received "
                    "shape: %s" % self.sequence_length.shape
                )
        elif mask is not None:
            mask = tf.convert_to_tensor(mask)
            if mask.shape.ndims != 2:
                raise ValueError(
                    "Expected mask to be a 2D tensor, but received shape: %s"
                    % mask.shape
                )
            if not mask.dtype.is_bool:
                raise ValueError(
                    "Expected mask to be a boolean tensor, but received "
                    "dtype: %s" % repr(mask.dtype)
                )

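            # Lengths are recovered by counting True entries along the time axis;
            # the check below asserts the mask is right-padded (no valid step
            # after a padded one), otherwise that count would be wrong.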
            axis = 1 if not self.time_major else 0
            with tf.control_dependencies(
                [_check_sequence_is_right_padded(mask, self.time_major)]
            ):
                self.sequence_length = tf.math.reduce_sum(
                    tf.cast(mask, tf.int32), axis=axis, name="sequence_length"
                )
        else:
            # As the input tensor has been converted to time major,
            # the maximum sequence length should be inferred from
            # the first dimension.
            max_seq_len = tf.shape(tf.nest.flatten(inputs)[0])[0]
            self.sequence_length = tf.fill(
                [self.batch_size], max_seq_len, name="sequence_length"
            )

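        # zero_inputs mirrors the structure of a single time step and is used as
        # next_inputs once every example in the batch has finished.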
        self.zero_inputs = tf.nest.map_structure(
            lambda inp: tf.zeros_like(inp[0, :]), inputs
        )

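        # An example with sequence_length == 0 is finished immediately; otherwise
        # the first time step is read from the unstacked TensorArrays.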
        finished = tf.equal(0, self.sequence_length)
        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(
            all_finished,
            lambda: self.zero_inputs,
            lambda: tf.nest.map_structure(lambda inp: inp.read(0), self.input_tas),
        )
        return (finished, next_inputs)
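
For context, a minimal usage sketch of driving initialize() directly (assuming tensorflow_addons is installed and exposes tfa.seq2seq.TrainingSampler; the inputs and lengths below are made-up toy values):

import tensorflow as tf
import tensorflow_addons as tfa

# Toy batch-major inputs: [batch_size=4, max_time=7, depth=16] (illustrative values).
inputs = tf.random.normal([4, 7, 16])
sequence_length = tf.constant([7, 5, 3, 0], dtype=tf.int32)

sampler = tfa.seq2seq.TrainingSampler(time_major=False)
finished, next_inputs = sampler.initialize(inputs, sequence_length=sequence_length)

print(finished.numpy())   # [False False False  True] -- only the zero-length example is finished
print(next_inputs.shape)  # (4, 16) -- the t=0 slice of the inputs

# A right-padded boolean mask can be passed instead of sequence_length:
mask = tf.sequence_mask(sequence_length, maxlen=7)
finished, next_inputs = sampler.initialize(inputs, mask=mask)

Passing mask instead of sequence_length requires the mask to be right-padded, which tf.sequence_mask guarantees by construction.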