in reagent/replay_memory/circular_replay_buffer.py [0:0]
def sample_transition_batch(self, batch_size=None, indices=None):
    """Returns a batch of transitions (including any extra contents).

    If get_transition_elements has been overridden and defines elements not
    stored in self._store, None is returned for those elements and it is
    left to the child class to fill them in. For example, for the child class
    PrioritizedReplayBuffer, the contents of the sampling_probabilities are
    stored separately in a sum tree.

    When the transition is terminal next_state_batch has undefined contents.

    NOTE: This transition contains the indices of the sampled elements. These
    are only valid during the call to sample_transition_batch, i.e. they may
    be used by subclasses of this replay buffer but may point to different
    data as soon as sampling is done.

    NOTE: Tensors are reshaped. I.e., state is 2-D unless stack_size > 1.
    Scalar values are returned as (batch_size, 1) instead of (batch_size,).

    Args:
        batch_size: int, number of transitions returned. If None, it is
            inferred from `indices` when they are given; otherwise the
            buffer's default batch size (self._batch_size) is used.
        indices: None or Tensor, the indices of every transition in the
            batch. If None, sample the indices via sample_index_batch.

    Returns:
        transition_batch: self._batch_type built from one entry per name in
            self._transition_elements, in that order (None for elements
            this class does not know how to fill).

    Raises:
        AssertionError: If `indices` is not a torch.Tensor, if its length
            disagrees with an explicitly passed `batch_size`, or if a
            "next_<name>" element has no "<name>" entry in the store.
    """
    if indices is None:
        if batch_size is None:
            batch_size = self._batch_size
        indices = self.sample_index_batch(batch_size)
    else:
        assert isinstance(
            indices, torch.Tensor
        ), f"Indices {indices} have type {type(indices)} instead of torch.Tensor"
        indices = indices.to(torch.int64)
        if batch_size is None:
            # Explicit indices already determine the batch size; do not force
            # callers to repeat it (or to match the default batch size).
            batch_size = len(indices)
    assert len(indices) == batch_size

    # 2-D index tensor of size (batch_size, update_horizon): the i-th row
    # holds the multistep indices starting at indices[i]. The offsets are
    # created on the same device as `indices` so the broadcast add never
    # mixes devices.
    horizon_offsets = torch.arange(self._update_horizon, device=indices.device)
    multistep_indices = indices.unsqueeze(1) + horizon_offsets
    multistep_indices %= self._replay_capacity
    steps = self._get_steps(multistep_indices)

    # `steps_for_timeline_format` is forwarded to the next_* / reward lookups
    # to toggle whether they return a list batch of length steps.
    if self._return_as_timeline_format:
        # Timeline format: "next" is the immediate successor.
        next_indices = (indices + 1) % self._replay_capacity
        steps_for_timeline_format = steps
    else:
        # Otherwise "next" is `steps` transitions ahead of the sampled index.
        next_indices = (indices + steps) % self._replay_capacity
        steps_for_timeline_format = None

    batch_arrays = []
    for element_name in self._transition_elements:
        if element_name == "state":
            batch = self._get_batch_for_indices("observation", indices)
        elif element_name == "next_state":
            batch = self._get_batch_for_indices(
                "observation", next_indices, steps_for_timeline_format
            )
        elif element_name == "indices":
            batch = indices
        elif element_name == "terminal":
            # The terminal flag is read at the last step of the multistep
            # window, not at the sampled index itself.
            terminal_indices = (indices + steps - 1) % self._replay_capacity
            batch = self._store["terminal"][terminal_indices].to(torch.bool)
        elif element_name == "reward":
            if self._return_as_timeline_format or self._return_everything_as_stack:
                batch = self._get_batch_for_indices(
                    "reward", indices, steps_for_timeline_format
                )
            else:
                batch = self._reduce_multi_step_reward(multistep_indices, steps)
        elif element_name == "step":
            batch = steps
        elif element_name in self._store:
            batch = self._get_batch_for_indices(element_name, indices)
        elif element_name.startswith("next_"):
            # "next_<name>" is served from the "<name>" store entry, looked
            # up at the next indices.
            store_name = element_name[len("next_") :]
            assert (
                store_name in self._store
            ), f"{store_name} is not in {self._store.keys()}"
            batch = self._get_batch_for_indices(
                store_name, next_indices, steps_for_timeline_format
            )
        else:
            # We assume the other elements are filled in by the subclass.
            batch = None

        # Always expose a second dim: (batch_size,) -> (batch_size, 1).
        if isinstance(batch, torch.Tensor) and batch.ndim == 1:
            batch = batch.unsqueeze(1)
        batch_arrays.append(batch)
    return self._batch_type(*batch_arrays)