def _as_dataset()

in tf_agents/replay_buffers/episodic_replay_buffer.py [0:0]


  def _as_dataset(self,
                  sample_batch_size=None,
                  num_steps=None,
                  sequence_preprocess_fn=None,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE):
    """Creates a dataset that returns episodes entries from the buffer.

    The dataset behaves differently depending on if `num_steps` is provided or
    not.  If `num_steps = None`, then entire episodes are sampled uniformly at
    random from the buffer.  If `num_steps != None`, then we attempt to sample
    uniformly across frames of all the episodes, and return subsets of length
    `num_steps`.  The algorithm for this is roughly:

    1. Sample an episode with a probability proportional to its length.
    2. If the length of the episode is less than `num_steps`, drop it.
    3. Sample a starting location `start` in `[0, len(episode) - num_steps]`
    4. Take a slice `[start, start + num_steps]`.

    The larger `num_steps` is, the higher the likelihood of edge effects (e.g.,
    certain frames not being visited often because they are near the start
    or end of an episode).  In the worst case, if `num_steps` is greater than
    most episode lengths, those episodes will never be visited.

    Args:
      sample_batch_size: (Optional.) An optional batch_size to specify the
        number of items to return. See as_dataset() documentation.
      num_steps: (Optional.) Scalar int.  How many contiguous frames to get
        per entry. Default is `None`: return full-length episodes.
      sequence_preprocess_fn: (Optional.) Preprocessing function for sequences
        before they are sharded into subsequences of length `num_steps` and
        batched.
      num_parallel_calls: Number of parallel calls to use in the
        dataset pipeline when extracting episodes.  Default is to have
        tensorflow determine the optimal number of calls.

    Returns:
      A dataset of type tf.data.Dataset, elements of which are 2-tuples of:

        - An item or sequence of items sampled uniformly from the buffer.
        - BufferInfo NamedTuple, containing the episode id.

    Raises:
      ValueError: If the data spec contains lists that must be converted to
        tuples.
      NotImplementedError: If `sequence_preprocess_fn != None` is passed in.
    """
    if sequence_preprocess_fn is not None:
      raise NotImplementedError('sequence_preprocess_fn is not supported.')

    # data_nest.flatten does not flatten python lists, tf.nest.flatten does.
    if tf.nest.flatten(self._data_spec) != data_nest.flatten(self._data_spec):
      raise ValueError(
          'Cannot perform gather; data spec contains lists and this conflicts '
          'with gathering operator.  Convert any lists to tuples.  '
          'For example, if your spec looks like [a, b, c], '
          'change it to (a, b, c).  Spec structure is:\n  {}'.format(
              tf.nest.map_structure(lambda spec: spec.dtype, self._data_spec)))

    seed_per_episode = distributions_util.gen_new_seed(
        self._seed,
        salt='per_episode')

    episode_id_buffer_size = self._buffer_size * (sample_batch_size or 1)

    def _get_episode_locations(_):
      """Samples a batch of episode locations from the buffer.

      When `num_steps` is None, episodes are drawn without length weighting;
      otherwise they are drawn weighted by episode length so that frames
      (rather than whole episodes) are sampled roughly uniformly.  When the
      buffer only serves completed episodes, locations of episodes not yet
      marked completed are masked out, so fewer than `episode_id_buffer_size`
      locations may be returned.

      Args:
        _: Ignored; the counter value fed in by `tf.data.experimental.Counter`.

      Returns:
        A tensor of episode locations (unbatched downstream into single rows).
      """
      sample_kwargs = {}
      if num_steps is not None:
        # Weigh episodes by length to approximate uniform sampling over frames.
        sample_kwargs['weigh_by_episode_length'] = True
      sampled_ids = self._sample_episode_ids(
          shape=[episode_id_buffer_size], seed=self._seed, **sample_kwargs)
      locations = self._get_episode_id_location(sampled_ids)

      if not self._completed_only:
        return locations
      # Keep only the episodes flagged as completed.
      completed_mask = self._episode_completed.sparse_read(locations)
      return tf.boolean_mask(tensor=locations, mask=completed_mask)

    ds = tf.data.experimental.Counter().map(_get_episode_locations).unbatch()

    if num_steps is None:
      @tf.autograph.experimental.do_not_convert
      def _read_data_and_id(row):
        """Returns `(episode_values, episode_id)` for the episode at `row`."""
        episode_values = self._data_table.get_episode_values(row)
        episode_id = self._id_table.read(row)
        return episode_values, episode_id
      ds = ds.map(_read_data_and_id, num_parallel_calls=num_parallel_calls)
    else:
      @tf.autograph.experimental.do_not_convert
      def _read_tensor_list_and_id(row):
        """Reads one episode as flat TensorLists plus its id and frame count.

        Args:
          row: Location of the episode in the underlying tables.

        Returns:
          A 3-tuple `(flat_tensor_lists, episode_id, num_frames)`, where
          `flat_tensor_lists` is a tuple holding one TensorList per flattened
          leaf of the episode's list structure, and `num_frames` is the
          shortest of those lists' lengths.
        """
        episode_lists = self._data_table.get_episode_lists(row)
        flat_lists = tuple(tf.nest.flatten(episode_lists))
        # Writers may race with this read, so the lists can have different
        # lengths; only the common prefix (the minimum length) is fully valid.
        lengths = [list_ops.tensor_list_length(lst) for lst in flat_lists]
        num_frames = tf.reduce_min(lengths)
        return flat_lists, self._id_table.read(row), num_frames

      ds = ds.map(
          _read_tensor_list_and_id, num_parallel_calls=num_parallel_calls)

      def _filter_by_length(unused_flat_lists, unused_id, num_frames):
        """Keeps only episodes that have at least `num_steps` valid frames."""
        long_enough = num_frames >= num_steps
        return long_enough

      ds = ds.filter(_filter_by_length)

      @tf.autograph.experimental.do_not_convert
      def _random_slice(flat_tensor_lists, id_, num_frames):
        """Gathers a random window of `num_steps` consecutive frames.

        The window start is drawn uniformly from `[0, num_frames - num_steps]`
        and the gathered frames are the half-open range
        `[start, start + num_steps)`.

        Args:
          flat_tensor_lists: Tuple of TensorLists, paired element-wise with the
            flattened `self._data_spec`.
          id_: Episode id, passed through unchanged.
          num_frames: Number of valid frames in the episode.

        Returns:
          A 2-tuple `(flat, id_)`, where `flat` is a tuple of gathered tensors
          in flattened `self._data_spec` order.
        """
        # `maxval` is exclusive, hence the `+ 1`.
        start = tf.random.uniform((),
                                  minval=0,
                                  maxval=num_frames - num_steps + 1,
                                  dtype=tf.int32,
                                  seed=seed_per_episode)
        frame_indices = tf.range(start, start + num_steps)

        gathered = []
        for tensor_list, spec in zip(flat_tensor_lists,
                                     tf.nest.flatten(self._data_spec)):
          gathered.append(
              list_ops.tensor_list_gather(
                  tensor_list,
                  indices=frame_indices,
                  element_dtype=spec.dtype,
                  element_shape=spec.shape))
        return tuple(gathered), id_

      ds = ds.map(_random_slice, num_parallel_calls=num_parallel_calls)

      def set_shape_and_restore_structure(flat_data, id_):
        """Pins the leading dimension to `num_steps` and re-nests the data.

        Args:
          flat_data: Tuple of sliced tensors in flattened `self._data_spec`
            order.
          id_: Episode id, passed through unchanged.

        Returns:
          A 2-tuple `(data, id_)`, where `data` has the structure of
          `self._data_spec`.
        """
        def restore_shape(t_sliced):
          # Only a tensor of known rank can have its shape partially set.
          # Tensors of unknown rank are passed through unchanged. They must
          # still be returned; previously they became None, which corrupted
          # the structure handed to `pack_sequence_as`.
          if t_sliced.shape.rank is not None:
            t_sliced.set_shape([num_steps] + [None] * (t_sliced.shape.rank - 1))
          return t_sliced
        shaped_flat = [restore_shape(x) for x in flat_data]
        return tf.nest.pack_sequence_as(self._data_spec, shaped_flat), id_

      ds = ds.map(set_shape_and_restore_structure)

    if sample_batch_size:
      if num_steps is None:
        raise ValueError("""`num_steps` must be set if `sample_batch_size` is