def _single_deterministic_pass_dataset()

in tf_agents/replay_buffers/tf_uniform_replay_buffer.py [0:0]


  def _single_deterministic_pass_dataset(self,
                                         sample_batch_size=None,
                                         num_steps=None,
                                         sequence_preprocess_fn=None,
                                         num_parallel_calls=None):
    """Creates a dataset that returns entries from the buffer in fixed order.

    Args:
      sample_batch_size: (Optional.) An optional batch_size to specify the
        number of items to return. See as_dataset() documentation.
      num_steps: (Optional.) Specifies that sub-episodes of length `num_steps`
        are desired. See as_dataset() documentation.
      sequence_preprocess_fn: (Optional.) Preprocessing function for sequences
        before they are sharded into subsequences of length `num_steps` and
        batched.
      num_parallel_calls: (Optional.) Number of elements to process in
        parallel. See as_dataset() documentation.

    Returns:
      A dataset of type tf.data.Dataset, elements of which are 2-tuples of:

        - An item or sequence of items or batch thereof
        - Auxiliary info for the items (i.e. ids, probs).

    Raises:
      ValueError: If `dataset_drop_remainder` is set and either
        `sample_batch_size > self.batch_size` or `num_steps > self.max_length`.
        In either case all data would be dropped.
      NotImplementedError: If a non-None `sequence_preprocess_fn` is passed in.
    """
    if sequence_preprocess_fn is not None:
      raise NotImplementedError('sequence_preprocess_fn is not supported.')
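    # Resolve values that are known at graph-construction time; the
    # drop_remainder sanity checks below can only run when both sides of the
    # comparison are available as static Python constants.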
    static_size = tf.get_static_value(sample_batch_size)
    static_num_steps = tf.get_static_value(num_steps)
    static_self_batch_size = tf.get_static_value(self._batch_size)
    static_self_max_length = tf.get_static_value(self._max_length)
    if (self._dataset_drop_remainder
        and static_size is not None
        and static_self_batch_size is not None
        and static_size > static_self_batch_size):
      raise ValueError(
          'sample_batch_size ({}) > self.batch_size ({}) and '
          'dataset_drop_remainder is True.  In '
          'this case, ALL data will be dropped by the deterministic dataset.'
          .format(static_size, static_self_batch_size))
    if (self._dataset_drop_remainder
        and static_num_steps is not None
        and static_self_max_length is not None
        and static_num_steps > static_self_max_length):
      raise ValueError(
          'num_steps ({}) > self.max_length ({}) and '
          'dataset_drop_remainder is True.  In '
          'this case, ALL data will be dropped by the deterministic dataset.'
          .format(static_num_steps, static_self_max_length))

    def get_row_ids(_):
      """Passed to Dataset.range(self._batch_size).flat_map(.), gets row ids."""
      with tf.device(self._device), tf.name_scope(self._scope):
        with tf.name_scope('single_deterministic_pass_dataset'):
          # Here we pass num_steps=None because _valid_range_ids uses
          # num_steps to determine a hard stop when sampling num_steps starting
          # from the returned indices.  But in our case, we want all the indices
          # and we'll use TF dataset's window() mechanism to get
          # num_steps-length blocks.  The window mechanism handles this stuff
          # for us.
          min_frame_offset, max_frame_offset = _valid_range_ids(
              self._get_last_id(), self._max_length, num_steps=None)
          tf.compat.v1.assert_less(
              min_frame_offset,
              max_frame_offset,
              message='TFUniformReplayBuffer is empty. Make sure to add items '
              'before asking the buffer for data.')

          min_max_frame_range = tf.range(min_frame_offset, max_frame_offset)

          window_shift = self._dataset_window_shift
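          # Dataset.window() emits each window as a nested dataset of
          # num_steps ids; batching that nested dataset collapses it back
          # into a single id tensor of length num_steps (or shorter, when
          # drop_remainder is False and the final window is truncated).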
          def group_windows(ds_, drop_remainder=self._dataset_drop_remainder):
            return ds_.batch(num_steps, drop_remainder=drop_remainder)

          if sample_batch_size is None:
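            # Without a sample_batch_size, emit items (or num_steps-length
            # windows of items) one batch row at a time, in row order
            # b = 0, 1, ..., batch_size - 1.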
            def row_ids(b):
              # Create a vector of shape [num_frames] and slice it along each
              # frame.
              ids = tf.data.Dataset.from_tensor_slices(
                  b * self._max_length + min_max_frame_range)
              if num_steps is not None:
                ids = (ids.window(num_steps, shift=window_shift)
                       .flat_map(group_windows))
              return ids
            return tf.data.Dataset.range(self._batch_size).flat_map(row_ids)
          else:
            def batched_row_ids(batch):
              # Create a matrix of indices shaped [num_frames, batch_size]
              # and slice it along each frame row to get groups of batches
              # for frame 0, frame 1, ...
              return tf.data.Dataset.from_tensor_slices(
                  (min_max_frame_range[:, tf.newaxis]
                   + batch * self._max_length))

            indices_ds = (
                tf.data.Dataset.range(self._batch_size)
                .batch(sample_batch_size,
                       drop_remainder=self._dataset_drop_remainder)
                .flat_map(batched_row_ids))

            if num_steps is not None:
              # We have sequences of num_frames rows shaped [sample_batch_size].
              # Window and group these to rows of shape
              # [num_steps, sample_batch_size], then
              # transpose them to get index tensors of shape
              # [sample_batch_size, num_steps].
              def group_windows_drop_remainder(d):
                return group_windows(d, drop_remainder=True)

              indices_ds = (indices_ds.window(num_steps, shift=window_shift)
                            .flat_map(group_windows_drop_remainder)
                            .map(tf.transpose))

            return indices_ds

    # Get our indices as a dataset; each time we reinitialize the iterator we
    # update our min/max id bounds from the state of the replay buffer.
    ds = tf.data.Dataset.range(1).flat_map(get_row_ids)

    def get_data(id_):
      with tf.device(self._device), tf.name_scope(self._scope):
        with tf.name_scope('single_deterministic_pass_dataset'):
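          # Row ids can exceed the table size once the buffer has wrapped;
          # the modulo maps them into the circular storage table.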
          data = self._data_table.read(id_ % self._capacity)
      buffer_info = BufferInfo(ids=id_, probabilities=())
      return (data, buffer_info)

    # Deterministic even though num_parallel_calls > 1.  Operations are
    # run in parallel but then the results are returned in original stream
    # order.
    ds = ds.map(get_data, num_parallel_calls=num_parallel_calls)

    return ds
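
A minimal usage sketch (not part of the source): this private helper is normally reached through the public as_dataset() call with single_deterministic_pass=True. The data spec and sizes below are illustrative only.

  import tensorflow as tf
  from tf_agents.replay_buffers import tf_uniform_replay_buffer

  # Illustrative scalar data spec; real buffers typically store trajectories.
  data_spec = tf.TensorSpec(shape=(), dtype=tf.float32)
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec, batch_size=2, max_length=8)

  # Each add_batch() call stores one frame per batch row.
  for i in range(8):
    replay_buffer.add_batch(tf.constant([float(i), float(i) + 100.0]))

  # single_deterministic_pass=True routes as_dataset() through
  # _single_deterministic_pass_dataset(): every stored item is visited once,
  # in fixed order, instead of being sampled uniformly at random.
  ds = replay_buffer.as_dataset(num_steps=2, single_deterministic_pass=True)
  for item, buffer_info in ds:
    # item is a [num_steps] window of frames; buffer_info.ids holds the
    # corresponding row ids.
    pass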