in rlkit/envs/vae_wrapper.py [0:0]
def sample_goals(self, batch_size):
    """Sample a batch of goals according to ``self._goal_sampling_mode``.

    Supported modes:

    * ``'custom_goal_sampler'`` -- delegate entirely to
      ``self.custom_goal_sampler(batch_size)`` and return its result as-is.
    * ``'presampled'`` -- pick random entries (with replacement) from every
      array in ``self._presampled_goals``; if image goals are present they
      are re-encoded into ``'latent_desired_goal'``. Returned without any
      further post-processing.
    * ``'env'`` -- ask the wrapped env for goals and encode the entry stored
      under ``self.vae_input_desired_goal_key``.
    * ``'reset_of_env'`` -- take the wrapped env's current goal, add a
      leading axis to every entry and encode it (``batch_size`` must be 1).
    * ``'vae_prior'`` -- draw latents from ``self._sample_vae_prior``.

    For the last three modes the latents are written to both
    ``'desired_goal'`` and ``'latent_desired_goal'``; when
    ``self._decode_goals`` is truthy the decoded goals (and the
    image/proprio parts derived from them) are added to the dict as well.

    :param batch_size: number of goals to sample.
    :return: the goal dict described above (or whatever the custom sampler
        returns in ``'custom_goal_sampler'`` mode).
    :raises RuntimeError: if the sampling mode is not one of the above.
    """
    # TODO: make mode a parameter you pass in
    mode = self._goal_sampling_mode

    if mode == 'custom_goal_sampler':
        return self.custom_goal_sampler(batch_size)

    if mode == 'presampled':
        picks = np.random.randint(0, self.num_goals_presampled, batch_size)
        batch = {}
        for key, values in self._presampled_goals.items():
            batch[key] = values[picks]
        # Re-encode the images so the latent goal comes from the latest VAE.
        if 'image_desired_goal' in batch:
            batch['latent_desired_goal'] = self._encode(
                batch['image_desired_goal']
            )
        return batch

    # The remaining modes all produce a (goals, latents) pair that goes
    # through the shared post-processing below.
    if mode == 'env':
        goals = self.wrapped_env.sample_goals(batch_size)
        latents = self._encode(goals[self.vae_input_desired_goal_key])
    elif mode == 'reset_of_env':
        assert batch_size == 1
        current_goal = self.wrapped_env.get_goal()
        goals = {}
        for key, value in current_goal.items():
            # ``value[None]`` prepends a length-1 axis to the single goal.
            goals[key] = value[None]
        latents = self._encode(goals[self.vae_input_desired_goal_key])
    elif mode == 'vae_prior':
        goals = {}
        latents = self._sample_vae_prior(batch_size)
    else:
        raise RuntimeError("Invalid: {}".format(mode))

    decoded = self._decode(latents) if self._decode_goals else None
    images, proprio = self._image_and_proprio_from_decoded(decoded)

    goals['desired_goal'] = latents
    goals['latent_desired_goal'] = latents
    if proprio is not None:
        goals['proprio_desired_goal'] = proprio
    if images is not None:
        goals['image_desired_goal'] = images
    if decoded is not None:
        # Overwrites any env-provided entry with the VAE reconstruction.
        goals[self.vae_input_desired_goal_key] = decoded
    return goals