in reagent/replay_memory/circular_replay_buffer.py [0:0]
def add(self, **kwargs):
    """Adds a transition to the replay memory.

    This function checks the types and handles the padding at the beginning
    of an episode, then calls the _add function.

    Since the next_observation in a transition will be the observation added
    next, there is no need to pass it.

    If the replay memory is at capacity, the oldest transition is discarded.

    Only accepts kwargs, which must contain observation, action, reward, and
    terminal as keys.
    """
    if not self._initialized_buffer:
        self.initialize_buffer(**kwargs)

    self._check_add_types(**kwargs)
    last_idx = (self.cursor() - 1) % self._replay_capacity
    if self.is_empty() or self._store["terminal"][last_idx]:
        self._num_transitions_in_current_episode = 0
        for _ in range(self._stack_size - 1):
            # Child classes can rely on the padding transitions being filled
            # with zeros. This is useful when there is a priority argument.
            self._add_zero_transition()
    # Remember: the last update_horizon transitions shouldn't be sampled,
    # because their multi-step returns are not complete yet.
    cur_idx = self.cursor()
    self.set_index_valid_status(idx=cur_idx, is_valid=False)
    if self._num_transitions_in_current_episode >= self._update_horizon:
        idx = (cur_idx - self._update_horizon) % self._replay_capacity
        self.set_index_valid_status(idx=idx, is_valid=True)
    self._add(**kwargs)
    self._num_transitions_in_current_episode += 1

    # Mark the next stack_size - 1 indices as invalid (note that the cursor
    # has advanced by 1 after _add).
    for i in range(self._stack_size - 1):
        idx = (self.cursor() + i) % self._replay_capacity
        self.set_index_valid_status(idx=idx, is_valid=False)
    if kwargs["terminal"]:
        # Since the frame (cur_idx) we just inserted was terminal, we now mark
        # the last "num_back" transitions as valid for sampling (including
        # cur_idx). This is because next_state is not relevant for those
        # terminal (multi-step) transitions.
        # NOTE: this was not accounted for by the original Dopamine buffer.
        # It is not a big problem, since after update_horizon steps the
        # original Dopamine buffer will make these frames available for
        # sampling. But that is update_horizon steps too late. If we train
        # right after an episode terminates, this can result in missing the
        # bulk of the rewards at the end of the most recent episode.
        num_back = min(
            self._num_transitions_in_current_episode, self._update_horizon
        )
        for i in range(num_back):
            idx = (cur_idx - i) % self._replay_capacity
            self.set_index_valid_status(idx=idx, is_valid=True)
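
A minimal usage sketch of add() over a single episode. The ReplayBuffer constructor arguments shown here (replay_capacity, stack_size, update_horizon, batch_size, gamma) and the sample_transition_batch() call are assumptions based on the Dopamine-style API this buffer derives from; the exact signatures in ReAgent may differ.

import numpy as np
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer

# Hypothetical constructor arguments; check the actual ReplayBuffer signature.
buffer = ReplayBuffer(
    replay_capacity=10000,  # max number of stored transitions
    stack_size=1,           # frames stacked per observation
    update_horizon=3,       # multi-step (n-step) return horizon
    batch_size=32,
    gamma=0.99,
)

# Feed one episode of length T. next_observation is never passed: the buffer
# treats the observation added at step t + 1 as the next state for step t.
T = 20
for t in range(T):
    buffer.add(
        observation=np.random.rand(4).astype(np.float32),
        action=np.int64(np.random.randint(2)),
        reward=np.float32(1.0),
        terminal=np.uint8(t == T - 1),  # mark the last step as terminal
    )

# Because add() re-validates the last min(episode_length, update_horizon)
# indices when a terminal transition arrives, the tail of this episode is
# sampleable immediately rather than update_horizon steps later.
batch = buffer.sample_transition_batch(batch_size=8)  # assumed sampling API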
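
And a standalone illustration (plain Python, not ReAgent code) of the validity bookkeeping implemented above: a position becomes sampleable once a full update_horizon of successors exists, or, thanks to the terminal branch, as soon as the episode ends. It simplifies by assuming stack_size=1 and ignoring wrap-around of the circular buffer.

def valid_indices(episode_len: int, update_horizon: int) -> set:
    """Which positions of a single finished episode end up valid for sampling."""
    valid = set()
    for t in range(episode_len):  # t-th call to add() within the episode
        if t >= update_horizon:
            # mirrors: set_index_valid_status(cur_idx - update_horizon, True)
            valid.add(t - update_horizon)
    # terminal add: re-validate the last min(episode_len, update_horizon) steps
    for i in range(min(episode_len, update_horizon)):
        valid.add(episode_len - 1 - i)
    return valid

# With update_horizon=3, even a short 2-step episode is fully sampleable right
# away; without the terminal branch it would contain no valid index at all.
assert valid_indices(episode_len=2, update_horizon=3) == {0, 1}
assert valid_indices(episode_len=6, update_horizon=3) == {0, 1, 2, 3, 4, 5}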