in tf_agents/replay_buffers/episodic_replay_buffer.py [0:0]
def _as_dataset(self,
sample_batch_size=None,
num_steps=None,
sequence_preprocess_fn=None,
num_parallel_calls=tf.data.experimental.AUTOTUNE):
"""Creates a dataset that returns episodes entries from the buffer.
The dataset behaves differently depending on if `num_steps` is provided or
not. If `num_steps = None`, then entire episodes are sampled uniformly at
random from the buffer. If `num_steps != None`, then we attempt to sample
uniformly across frames of all the episodes, and return subsets of length
`num_steps`. The algorithm for this is roughly:
1. Sample an episode with a probability proportional to its length.
2. If the length of the episode is less than `num_steps`, drop it.
3. Sample a starting location `start` uniformly in `[0, len(episode) - num_steps]`.
4. Take the slice `[start, start + num_steps)`.
The larger `num_steps` is, the higher the likelihood of edge effects (e.g.,
certain frames not being visited often because they are near the start
or end of an episode). In the worst case, if `num_steps` is greater than
most episode lengths, those episodes will never be visited.
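A minimal usage sketch via the public `as_dataset()` wrapper (illustrative
only; `buffer` is assumed to be a populated `EpisodicReplayBuffer`):
  dataset = buffer.as_dataset(sample_batch_size=4, num_steps=10)
  for experience, buffer_info in dataset.take(1):
    # `experience` matches `data_spec` with outer dims [4, 10];
    # `buffer_info` carries the sampled episode ids.
    pass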
Args:
sample_batch_size: (Optional.) An optional batch_size to specify the
number of items to return. See as_dataset() documentation.
num_steps: (Optional.) Scalar int. How many contiguous frames to get
per entry. Default is `None`: return full-length episodes.
sequence_preprocess_fn: (Optional.) Preprocessing function for sequences
before they are sharded into subsequences of length `num_steps` and
batched.
num_parallel_calls: Number of parallel calls to use in the
dataset pipeline when extracting episodes. Default is to have
TensorFlow determine the optimal number of calls.
Returns:
A dataset of type tf.data.Dataset, elements of which are 2-tuples of:
- An item or sequence of items sampled uniformly from the buffer.
- BufferInfo NamedTuple, containing the episode id.
Raises:
ValueError: If the data spec contains lists that must be converted to
tuples.
NotImplementedError: If a `sequence_preprocess_fn` is provided.
"""
if sequence_preprocess_fn is not None:
raise NotImplementedError('sequence_preprocess_fn is not supported.')
# data_nest.flatten does not flatten python lists, but tf.nest.flatten does.
if tf.nest.flatten(self._data_spec) != data_nest.flatten(self._data_spec):
raise ValueError(
'Cannot perform gather; data spec contains lists and this conflicts '
'with gathering operator. Convert any lists to tuples. '
'For example, if your spec looks like [a, b, c], '
'change it to (a, b, c). Spec structure is:\n {}'.format(
tf.nest.map_structure(lambda spec: spec.dtype, self._data_spec)))
seed_per_episode = distributions_util.gen_new_seed(
self._seed,
salt='per_episode')
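# Number of candidate episode ids drawn per dataset step below; scaled by
# `sample_batch_size` so that a single draw can feed an entire batch once it
# is unbatched.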
episode_id_buffer_size = self._buffer_size * (sample_batch_size or 1)
def _get_episode_locations(_):
"""Sample episode ids according to value of num_steps."""
if num_steps is None:
# Just want to get a uniform sampling of episodes.
episode_ids = self._sample_episode_ids(
shape=[episode_id_buffer_size], seed=self._seed)
else:
# To sample approximately uniformly across frames, weight each episode's
# sampling probability by its length.
episode_ids = self._sample_episode_ids(
shape=[episode_id_buffer_size],
weigh_by_episode_length=True,
seed=self._seed)
episode_locations = self._get_episode_id_location(episode_ids)
if self._completed_only:
return tf.boolean_mask(
tensor=episode_locations,
mask=self._episode_completed.sparse_read(episode_locations))
else:
return episode_locations
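# Each element of the endless `Counter()` stream triggers one draw of episode
# locations above; `unbatch()` then yields them one row at a time to the maps
# below.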
ds = tf.data.experimental.Counter().map(_get_episode_locations).unbatch()
if num_steps is None:
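# Full-episode path: read each sampled row's episode values and its id.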
@tf.autograph.experimental.do_not_convert
def _read_data_and_id(row):
return (
self._data_table.get_episode_values(row),
self._id_table.read(row))
ds = ds.map(_read_data_and_id, num_parallel_calls=num_parallel_calls)
else:
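# Fixed-length path: read the raw TensorLists, drop episodes shorter than
# `num_steps`, then take a random `num_steps`-frame slice.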
@tf.autograph.experimental.do_not_convert
def _read_tensor_list_and_id(row):
"""Read the TensorLists out of the table row, get id and num_frames."""
# Return a flattened tensor list
flat_tensor_lists = tuple(
tf.nest.flatten(self._data_table.get_episode_lists(row)))
# Due to race conditions, not all entries may have been written for the
# given episode. Use the minimum list length across the flattened structure
# as the valid available length.
num_frames = tf.reduce_min(
[list_ops.tensor_list_length(l) for l in flat_tensor_lists])
return flat_tensor_lists, self._id_table.read(row), num_frames
ds = ds.map(
_read_tensor_list_and_id, num_parallel_calls=num_parallel_calls)
def _filter_by_length(unused_1, unused_2, num_frames):
# Remove episodes that are too short.
return num_frames >= num_steps
ds = ds.filter(_filter_by_length)
@tf.autograph.experimental.do_not_convert
def _random_slice(flat_tensor_lists, id_, num_frames):
"""Take a random slice from the episode, of length num_steps."""
# Sample a start index uniformly from [0, num_frames - num_steps].
start_slice = tf.random.uniform((),
minval=0,
maxval=num_frames - num_steps + 1,
dtype=tf.int32,
seed=seed_per_episode)
end_slice = start_slice + num_steps
flat_spec = tf.nest.flatten(self._data_spec)
# Pull out the `num_steps` frames in [start_slice, end_slice).
flat = tuple(
list_ops.tensor_list_gather( # pylint: disable=g-complex-comprehension
t, indices=tf.range(start_slice, end_slice),
element_dtype=spec.dtype, element_shape=spec.shape)
for t, spec in zip(flat_tensor_lists, flat_spec))
return flat, id_
ds = ds.map(_random_slice, num_parallel_calls=num_parallel_calls)
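# `tensor_list_gather` yields tensors without a static leading dimension;
# reattach `num_steps` and pack the flat tensors back into the `data_spec`
# structure.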
def set_shape_and_restore_structure(flat_data, id_):
def restore_shape(t_sliced):
if t_sliced.shape.rank is not None:
t_sliced.set_shape([num_steps] + [None] * (t_sliced.shape.rank - 1))
return t_sliced
shaped_flat = [restore_shape(x) for x in flat_data]
return tf.nest.pack_sequence_as(self._data_spec, shaped_flat), id_
ds = ds.map(set_shape_and_restore_structure)
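# Batching requires fixed-length entries: full episodes vary in length, so
# `num_steps` must be provided whenever `sample_batch_size` is set.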
if sample_batch_size:
if num_steps is None:
raise ValueError("""`num_steps` must be set if `sample_batch_size` is