def sample_transition_batch()

in reagent/replay_memory/circular_replay_buffer.py


    def sample_transition_batch(self, batch_size=None, indices=None):
        """Returns a batch of transitions (including any extra contents).
        If get_transition_elements has been overridden and defines elements not
        stored in self._store, None will be returned and it will be
        left to the child class to fill it. For example, for the child class
        PrioritizedReplayBuffer, the contents of the
        sampling_probabilities are stored separately in a sum tree.
        When the transition is terminal next_state_batch has undefined contents.
        NOTE: This transition contains the indices of the sampled elements. These
        are only valid during the call to sample_transition_batch, i.e. they may
        be used by subclasses of this replay buffer but may point to different data
        as soon as sampling is done.
        NOTE: Tensors are reshaped. I.e., state is 2-D unless stack_size > 1.
        Scalar values are returned as (batch_size, 1) instead of (batch_size,).
        Args:
          batch_size: int, number of transitions returned. If None, the default
            batch_size will be used.
          indices: None or Tensor, the indices of every transition in the
            batch. If None, sample the indices uniformly.
        Returns:
          transition_batch: tuple of Tensors with the shape and type as in
            get_transition_elements().
        Raises:
          ValueError: If an element to be sampled is missing from the replay buffer.
        """
        if batch_size is None:
            batch_size = self._batch_size
        if indices is None:
            indices = self.sample_index_batch(batch_size)
        else:
            assert isinstance(
                indices, torch.Tensor
            ), f"Indices {indices} have type {type(indices)} instead of torch.Tensor"
            indices = indices.to(torch.int64)
        assert len(indices) == batch_size

        # Calculate a 2-D array of indices of shape (batch_size, update_horizon);
        # the i-th row contains the multistep indices starting at indices[i].
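        # Illustrative example (not from the source): with indices = [3, 7],
        # update_horizon = 2, and replay_capacity = 10, multistep_indices is
        # [[3, 4], [7, 8]], wrapped modulo the buffer capacity.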
        multistep_indices = indices.unsqueeze(1) + torch.arange(self._update_horizon)
        multistep_indices %= self._replay_capacity

        steps = self._get_steps(multistep_indices)

        # steps_for_timeline_format is passed to the next-state and reward
        # lookups below to toggle whether each batch entry is returned as a
        # list of length steps.
        if self._return_as_timeline_format:
            next_indices = (indices + 1) % self._replay_capacity
            steps_for_timeline_format = steps
        else:
            next_indices = (indices + steps) % self._replay_capacity
            steps_for_timeline_format = None
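        # Illustrative example (not from the source): with steps = [3, 1],
        # timeline mode starts reading at indices + 1 and returns per-step
        # lists of length steps[i], while the default mode jumps straight to
        # the n-step successor at indices + steps.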

        batch_arrays = []
        for element_name in self._transition_elements:
            if element_name == "state":
                batch = self._get_batch_for_indices("observation", indices)
            elif element_name == "next_state":
                batch = self._get_batch_for_indices(
                    "observation", next_indices, steps_for_timeline_format
                )
            elif element_name == "indices":
                batch = indices
            elif element_name == "terminal":
                terminal_indices = (indices + steps - 1) % self._replay_capacity
                batch = self._store["terminal"][terminal_indices].to(torch.bool)
            elif element_name == "reward":
                if self._return_as_timeline_format or self._return_everything_as_stack:
                    batch = self._get_batch_for_indices(
                        "reward", indices, steps_for_timeline_format
                    )
                else:
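                    # Default n-step path: collapse the rewards along each row
                    # of multistep_indices into a single multi-step return;
                    # _reduce_multi_step_reward applies the buffer's
                    # discounting.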
                    batch = self._reduce_multi_step_reward(multistep_indices, steps)
            elif element_name == "step":
                batch = steps
            elif element_name in self._store:
                batch = self._get_batch_for_indices(element_name, indices)
            elif element_name.startswith("next_"):
                store_name = element_name[len("next_") :]
                assert (
                    store_name in self._store
                ), f"{store_name} is not in {self._store.keys()}"
                batch = self._get_batch_for_indices(
                    store_name, next_indices, steps_for_timeline_format
                )
            else:
                # We assume the other elements are filled in by the subclass.
                batch = None

            # Ensure an explicit batch dimension: (batch_size,) -> (batch_size, 1).
            if isinstance(batch, torch.Tensor) and batch.ndim == 1:
                batch = batch.unsqueeze(1)
            batch_arrays.append(batch)
        return self._batch_type(*batch_arrays)
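
A minimal usage sketch (an illustration, not taken from the source file): the
ReplayBuffer constructor arguments and the add() keywords below follow the
Dopamine-style API this buffer mirrors, so exact names may differ between
versions.

    import numpy as np
    import torch
    from reagent.replay_memory.circular_replay_buffer import ReplayBuffer

    buffer = ReplayBuffer(replay_capacity=100, batch_size=4, update_horizon=3)
    for t in range(20):
        buffer.add(
            observation=np.zeros(4, dtype=np.float32),
            action=0,
            reward=1.0,
            terminal=(t % 10 == 9),  # end an episode every 10 steps
        )

    # Uniform sampling; scalar elements come back with shape (batch_size, 1).
    batch = buffer.sample_transition_batch()

    # Explicit indices must be an int64 tensor whose length equals batch_size.
    batch = buffer.sample_transition_batch(
        batch_size=2, indices=torch.tensor([5, 6], dtype=torch.int64)
    )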