in rlkit/data_management/online_vae_replay_buffer.py [0:0]
def refresh_latents(self, epoch):
    """Re-encode the stored images and refresh the statistics derived from them.

    For every valid entry (indices ``0 .. self._size - 1``), in batches of
    512, this method:

    * overwrites the latent observation of ``_obs`` and the latent
      observation / desired goal / achieved goal of ``_next_obs`` with
      ``self.env._encode`` applied to the normalized decoded images;
    * recomputes ``_exploration_rewards`` when
      ``_give_explr_reward_bonus`` is set;
    * recomputes ``_vae_sample_priorities`` when
      ``_prioritize_vae_samples`` is set.

    Afterwards it stores the per-dimension mean and standard deviation of
    the refreshed ``_obs`` latents on ``self.vae`` (``dist_mu`` /
    ``dist_std``) and, when prioritizing, rebuilds the normalized, flattened
    ``_vae_sample_probs``.

    Args:
        epoch: Current epoch. Stored on ``self.epoch`` and compared with
            ``self.start_skew_epoch`` to set ``self.skew``.

    Raises:
        AssertionError: If prioritization is enabled and the unnormalized
            priorities do not sum to a positive value.
    """
    self.epoch = epoch
    self.skew = (self.epoch > self.start_skew_epoch)
    batch_size = 512
    size = self._size

    if self.exploration_rewards_type == 'hash_count':
        # you have to count everything then compute exploration rewards
        # NOTE(review): this pass only normalizes each batch and discards
        # the result -- no counting call is visible here. Kept as-is in
        # case normalize_image matters; confirm whether a hash-count
        # update is missing or the pass can be removed.
        for start in range(0, size, batch_size):
            idxs = np.arange(start, min(start + batch_size, size))
            normalize_image(self._next_obs[self.decoded_obs_key][idxs])

    obs_sum = np.zeros(self.vae.representation_size)
    obs_square_sum = np.zeros(self.vae.representation_size)
    # Batch bounds are derived from ``range`` on every pass, so this loop
    # never inherits a stale upper bound from the hash-count pass above
    # (which previously made the first batch span the whole buffer).
    for start in range(0, size, batch_size):
        idxs = np.arange(start, min(start + batch_size, size))
        self._obs[self.observation_key][idxs] = self.env._encode(
            normalize_image(self._obs[self.decoded_obs_key][idxs])
        )
        self._next_obs[self.observation_key][idxs] = self.env._encode(
            normalize_image(self._next_obs[self.decoded_obs_key][idxs])
        )
        # WARNING: we only refresh the desired/achieved latents for
        # "next_obs". This means that obs[desired/achieve] will be invalid,
        # so make sure there's no code that references this.
        # TODO: enforce this with code and not a comment
        self._next_obs[self.desired_goal_key][idxs] = self.env._encode(
            normalize_image(
                self._next_obs[self.decoded_desired_goal_key][idxs]
            )
        )
        self._next_obs[self.achieved_goal_key][idxs] = self.env._encode(
            normalize_image(
                self._next_obs[self.decoded_achieved_goal_key][idxs]
            )
        )
        normalized_imgs = (
            normalize_image(self._next_obs[self.decoded_obs_key][idxs])
        )
        if self._give_explr_reward_bonus:
            rewards = self.exploration_reward_func(
                normalized_imgs,
                idxs,
                **self.priority_function_kwargs
            )
            self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
        if self._prioritize_vae_samples:
            if (
                self.exploration_rewards_type == self.vae_priority_type
                and self._give_explr_reward_bonus
            ):
                # Same metric as the exploration bonus: reuse the values
                # computed just above instead of evaluating it twice.
                self._vae_sample_priorities[idxs] = (
                    self._exploration_rewards[idxs]
                )
            else:
                self._vae_sample_priorities[idxs] = (
                    self.vae_prioritization_func(
                        normalized_imgs,
                        idxs,
                        **self.priority_function_kwargs
                    ).reshape(-1, 1)
                )
        # Running sums for the mean / std of the refreshed latents.
        latents = self._obs[self.observation_key][idxs]
        obs_sum += latents.sum(axis=0)
        obs_square_sum += np.power(latents, 2).sum(axis=0)

    self.vae.dist_mu = obs_sum / size
    # Var[x] = E[x^2] - E[x]^2 can come out slightly negative through
    # floating-point round-off; clamp at zero so sqrt does not yield NaN.
    variance = obs_square_sum / size - np.power(self.vae.dist_mu, 2)
    self.vae.dist_std = np.sqrt(np.maximum(variance, 0))

    if self._prioritize_vae_samples:
        # priority^power is calculated in the priority function
        # for image_bernoulli_prob or image_gaussian_inv_prob and
        # directly here if not.
        if self.vae_priority_type == 'vae_prob':
            self._vae_sample_priorities[:size] = relative_probs_from_log_probs(
                self._vae_sample_priorities[:size]
            )
            self._vae_sample_probs = self._vae_sample_priorities[:size]
        else:
            self._vae_sample_probs = (
                self._vae_sample_priorities[:size] ** self.power
            )
        p_sum = np.sum(self._vae_sample_probs)
        assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
        # In-place division on purpose: in the 'vae_prob' branch the probs
        # alias the priority buffer, matching the previous behavior.
        self._vae_sample_probs /= p_sum
        self._vae_sample_probs = self._vae_sample_probs.flatten()