in rlkit/data_management/obs_dict_replay_buffer.py [0:0]
def add_path(self, path):
    """Append one path (trajectory) to the circular replay buffer.

    The path is written starting at ``self._top``. If it does not fit in
    the remaining rows, the write wraps around to row 0. For every buffer
    index ``i`` written, ``self._idx_to_future_obs_idx[i]`` is set to the
    buffer indices of step ``i`` and all later steps of this same path,
    in path order.

    Args:
        path: Mapping with the keys ``"observations"``, ``"actions"``,
            ``"rewards"``, ``"next_observations"`` and ``"terminals"``.
            ``"rewards"`` is only used to determine the path length.
            NOTE(review): the remaining entries are presumably per-step
            arrays / lists of observation dicts -- confirm against
            ``flatten_n`` / ``flatten_dict``.

    Raises:
        ValueError: If the path has more steps than ``self.max_size``.
            Such a path would overwrite its own earlier steps while
            wrapping, so it is rejected before any state is modified.
    """
    obs = path["observations"]
    actions = path["actions"]
    rewards = path["rewards"]
    next_obs = path["next_observations"]
    terminals = path["terminals"]
    # Rewards are not stored by this method; they only give the length.
    path_len = len(rewards)
    if path_len > self.max_size:
        raise ValueError(
            "Cannot add a path of length {} to a replay buffer of "
            "size {}.".format(path_len, self.max_size)
        )

    actions = flatten_n(actions)
    if isinstance(self.env.action_space, Discrete):
        # One-hot encode: row ``a`` of the identity matrix for action ``a``.
        actions = np.eye(self._action_dim)[actions].reshape(
            (-1, self._action_dim)
        )

    # Built once; the same key list drives flattening and storage.
    keys = self.ob_keys_to_save + self.internal_keys
    obs = flatten_dict(obs, keys)
    next_obs = flatten_dict(next_obs, keys)
    obs = preprocess_obs_dict(obs)
    next_obs = preprocess_obs_dict(next_obs)

    start = self._top
    end = start + path_len
    if end >= self.max_size:
        # The path reaches the end of the buffer: write the first part
        # into rows [start, max_size) and the remainder into rows
        # [0, num_post_wrap_steps). When the path ends exactly at the
        # last row, the second part is empty.
        num_pre_wrap_steps = self.max_size - start
        num_post_wrap_steps = path_len - num_pre_wrap_steps
        # (buffer rows, path rows) pairs. A single slice selects whole
        # rows, so the same pair is valid for every stored array.
        segments = (
            (slice(start, self.max_size), slice(0, num_pre_wrap_steps)),
            (
                slice(0, num_post_wrap_steps),
                slice(num_pre_wrap_steps, path_len),
            ),
        )
        for buffer_rows, path_rows in segments:
            self._actions[buffer_rows] = actions[path_rows]
            self._terminals[buffer_rows] = terminals[path_rows]
            for key in keys:
                self._obs[key][buffer_rows] = obs[key][path_rows]
                self._next_obs[key][buffer_rows] = next_obs[key][path_rows]

        # Steps stored before the wrap: their future continues to the end
        # of the buffer and then through the wrapped-around rows.
        post_wrap_indices = np.arange(0, num_post_wrap_steps)
        for i in range(start, self.max_size):
            self._idx_to_future_obs_idx[i] = np.hstack((
                np.arange(i, self.max_size),
                post_wrap_indices,
            ))
        # Steps stored after the wrap: their future is the rest of the
        # wrapped-around rows.
        for i in range(num_post_wrap_steps):
            self._idx_to_future_obs_idx[i] = np.arange(
                i, num_post_wrap_steps
            )
    else:
        # The whole path fits in rows [start, end).
        buffer_rows = slice(start, end)
        self._actions[buffer_rows] = actions
        self._terminals[buffer_rows] = terminals
        for key in keys:
            self._obs[key][buffer_rows] = obs[key]
            self._next_obs[key][buffer_rows] = next_obs[key]
        for i in range(start, end):
            self._idx_to_future_obs_idx[i] = np.arange(i, end)

    # NOTE(review): future-index arrays of older paths that this write
    # partially overwrote are not touched here, so they may still refer to
    # rows that now hold this path's data -- confirm the sampler copes.
    self._top = end % self.max_size
    self._size = min(self._size + path_len, self.max_size)