in tf_agents/replay_buffers/tf_uniform_replay_buffer.py [0:0]
def _single_deterministic_pass_dataset(self,
                                       sample_batch_size=None,
                                       num_steps=None,
                                       sequence_preprocess_fn=None,
                                       num_parallel_calls=None):
"""Creates a dataset that returns entries from the buffer in fixed order.
Args:
sample_batch_size: (Optional.) An optional batch_size to specify the
number of items to return. See as_dataset() documentation.
num_steps: (Optional.) Optional way to specify that sub-episodes are
desired. See as_dataset() documentation.
sequence_preprocess_fn: (Optional.) Preprocessing function for sequences
before they are sharded into subsequences of length `num_steps` and
batched.
num_parallel_calls: (Optional.) Number elements to process in parallel.
See as_dataset() documentation.
Returns:
A dataset of type tf.data.Dataset, elements of which are 2-tuples of:
- An item or sequence of items or batch thereof
- Auxiliary info for the items (i.e. ids, probs).
Raises:
ValueError: If `dataset_drop_remainder` is set, and
`sample_batch_size > self.batch_size`. In this case all data will
be dropped.
NotImplementedError: If `sequence_preprocess_fn != None` is passed in.
"""
  if sequence_preprocess_fn is not None:
    raise NotImplementedError('sequence_preprocess_fn is not supported.')

  static_size = tf.get_static_value(sample_batch_size)
  static_num_steps = tf.get_static_value(num_steps)
  static_self_batch_size = tf.get_static_value(self._batch_size)
  static_self_max_length = tf.get_static_value(self._max_length)

  if (self._dataset_drop_remainder
      and static_size is not None
      and static_self_batch_size is not None
      and static_size > static_self_batch_size):
    raise ValueError(
        'sample_batch_size ({}) > self.batch_size ({}) and '
        'dataset_drop_remainder is True. In '
        'this case, ALL data will be dropped by the deterministic dataset.'
        .format(static_size, static_self_batch_size))
  if (self._dataset_drop_remainder
      and static_num_steps is not None
      and static_self_max_length is not None
      and static_num_steps > static_self_max_length):
    raise ValueError(
        'num_steps ({}) > self.max_length ({}) and '
        'dataset_drop_remainder is True. In '
        'this case, ALL data will be dropped by the deterministic dataset.'
        .format(static_num_steps, static_self_max_length))
  def get_row_ids(_):
    """Passed to Dataset.range(self._batch_size).flat_map(.), gets row ids."""
    with tf.device(self._device), tf.name_scope(self._scope):
      with tf.name_scope('single_deterministic_pass_dataset'):
        # Here we pass num_steps=None because _valid_range_ids uses num_steps
        # to determine a hard stop when sampling num_steps starting from the
        # returned indices. But in our case, we want all the indices and
        # we'll use the tf.data window() mechanism to get num_steps-length
        # blocks. The window mechanism handles this for us.
        min_frame_offset, max_frame_offset = _valid_range_ids(
            self._get_last_id(), self._max_length, num_steps=None)
        tf.compat.v1.assert_less(
            min_frame_offset,
            max_frame_offset,
            message='TFUniformReplayBuffer is empty. Make sure to add items '
            'before asking the buffer for data.')
        min_max_frame_range = tf.range(min_frame_offset, max_frame_offset)

        window_shift = self._dataset_window_shift
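        # window() below yields each window as a nested dataset of ids;
        # group_windows (applied via flat_map) batches a window's elements
        # into a single tensor whose leading dimension is num_steps.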
        def group_windows(ds_, drop_remainder=self._dataset_drop_remainder):
          return ds_.batch(num_steps, drop_remainder=drop_remainder)

        if sample_batch_size is None:
          def row_ids(b):
            # Create a vector of shape [num_frames] and slice it along each
            # frame.
            ids = tf.data.Dataset.from_tensor_slices(
                b * self._max_length + min_max_frame_range)
            if num_steps is not None:
              ids = (ids.window(num_steps, shift=window_shift)
                     .flat_map(group_windows))
            return ids

          return tf.data.Dataset.range(self._batch_size).flat_map(row_ids)
        else:
          def batched_row_ids(batch):
            # Create a matrix of indices shaped [num_frames, batch_size]
            # and slice it along each frame row to get groups of batches
            # for frame 0, frame 1, ...
            return tf.data.Dataset.from_tensor_slices(
                (min_max_frame_range[:, tf.newaxis]
                 + batch * self._max_length))

          indices_ds = (
              tf.data.Dataset.range(self._batch_size)
              .batch(sample_batch_size,
                     drop_remainder=self._dataset_drop_remainder)
              .flat_map(batched_row_ids))

          if num_steps is not None:
            # We have sequences of num_frames rows shaped [sample_batch_size].
            # Window and group these to rows of shape
            # [num_steps, sample_batch_size], then transpose them to get
            # index tensors of shape [sample_batch_size, num_steps].
            def group_windows_drop_remainder(d):
              return group_windows(d, drop_remainder=True)

            indices_ds = (indices_ds.window(num_steps, shift=window_shift)
                          .flat_map(group_windows_drop_remainder)
                          .map(tf.transpose))

          return indices_ds
  # Get our indices as a dataset; each time we reinitialize the iterator we
  # update our min/max id bounds from the state of the replay buffer.
  ds = tf.data.Dataset.range(1).flat_map(get_row_ids)

  def get_data(id_):
    with tf.device(self._device), tf.name_scope(self._scope):
      with tf.name_scope('single_deterministic_pass_dataset'):
        data = self._data_table.read(id_ % self._capacity)
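        # This is a deterministic pass over stored rows rather than a random
        # sample, so BufferInfo carries the ids but no sampling probabilities.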
        buffer_info = BufferInfo(ids=id_, probabilities=())
        return (data, buffer_info)

  # Deterministic even though num_parallel_calls > 1. Operations are run in
  # parallel but then the results are returned in original stream order.
  ds = ds.map(get_data, num_parallel_calls=num_parallel_calls)
  return ds
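
For intuition, here is a small, self-contained sketch (not part of the library) of the batched branch above, using made-up values (max_length=4, batch_size=3, sample_batch_size=2, num_steps=2, and a window shift equal to num_steps). It shows how window() + batch() + tf.transpose turn flat frame ids into [sample_batch_size, num_steps] index blocks:

import tensorflow as tf

# Hypothetical toy values; the real buffer derives these from its own state.
max_length = 4            # frames stored per batch row
batch_size = 3            # number of parallel rows in the buffer
sample_batch_size = 2
num_steps = 2
window_shift = num_steps  # assume non-overlapping windows

# Stand-in for min_max_frame_range; int64 to match tf.data.Dataset.range.
frame_range = tf.range(0, max_length, dtype=tf.int64)

def batched_row_ids(batch):
  # [max_length, sample_batch_size] matrix of flat ids for this group of rows.
  return tf.data.Dataset.from_tensor_slices(
      frame_range[:, tf.newaxis] + batch * max_length)

indices_ds = (
    tf.data.Dataset.range(batch_size)
    .batch(sample_batch_size, drop_remainder=True)
    .flat_map(batched_row_ids))

indices_ds = (
    indices_ds.window(num_steps, shift=window_shift)
    .flat_map(lambda w: w.batch(num_steps, drop_remainder=True))
    .map(tf.transpose))

for ids in indices_ds:
  print(ids.numpy())  # each element has shape [sample_batch_size, num_steps]
# [[0 1]
#  [4 5]]
# [[2 3]
#  [6 7]]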