in rlkit/launchers/skewfit_experiments.py [0:0]
def get_envs(variant):
    """Build the environment described by ``variant``.

    The base environment is created either with ``gym.make(variant['env_id'])``
    (after registering the multiworld environments) or by calling
    ``variant['env_class'](**variant['env_kwargs'])``. Unless
    ``variant['do_state_exp']`` is truthy, the base environment is then wrapped
    in an ``ImageEnv`` (if it is not one already) and a ``VAEWrappedEnv``.

    Keys read from ``variant``:
        render: passed as ``decode_goals``, ``render_goals`` and
            ``render_rollouts`` to ``VAEWrappedEnv`` (default ``False``).
        vae_path: a string is loaded with ``load_local_or_remote_file``; any
            other value is passed to ``VAEWrappedEnv`` as-is (default ``None``).
        reward_params: forwarded to ``VAEWrappedEnv`` (default ``{}``).
        init_camera: forwarded to ``ImageEnv`` (default ``None``).
        imsize: forwarded to ``ImageEnv`` (``None`` when absent).
        do_state_exp: when truthy, the base environment is returned unwrapped.
        presample_goals: when truthy, goals are generated (or loaded from
            ``presampled_goals_path``) and handed to a freshly built
            ``ImageEnv`` and ``VAEWrappedEnv``.
        presampled_goals_path: source of presampled goals; when ``None`` the
            goals are produced by ``generate_goal_dataset_fctn``.
        presample_image_goals_only: when ``presample_goals`` is falsy and this
            is truthy, goals are generated and installed on the image env via
            ``set_presampled_goals``.
        generate_goal_dataset_fctn, goal_generation_kwargs: callable and
            keyword arguments used to generate goals (required only in the
            branches that generate goals).
        image_env_kwargs: extra keyword arguments for the ``ImageEnv`` that is
            rebuilt in the ``presample_goals`` branch (default ``{}``).
        vae_wrapped_env_kwargs: extra keyword arguments for every
            ``VAEWrappedEnv`` construction (default ``{}``).

    Returns:
        The base environment when ``do_state_exp`` is truthy, otherwise the
        ``VAEWrappedEnv`` built around it.

    Raises:
        KeyError: if a key required by the selected branch is missing
            (``env_class``/``env_kwargs`` without ``env_id``,
            ``generate_goal_dataset_fctn`` or ``goal_generation_kwargs`` when
            goals must be generated).
    """
    from multiworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get(
        'presample_image_goals_only', False
    )
    presampled_goals_path = variant.get('presampled_goals_path', None)

    # A string is treated as a location to load the VAE from; anything else
    # is passed through unchanged.
    if isinstance(vae_path, str):
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    # State-based experiments use the raw environment without image/VAE
    # wrappers.
    if do_state_exp:
        return env

    def wrap_with_vae(wrapped_env, **extra_kwargs):
        # Single place for the VAEWrappedEnv construction shared by every
        # branch below. ``extra_kwargs`` and ``vae_wrapped_env_kwargs`` are
        # unpacked separately so that a duplicated keyword still raises
        # TypeError instead of being silently overridden.
        return VAEWrappedEnv(
            wrapped_env,
            vae,
            imsize=wrapped_env.imsize,
            decode_goals=render,
            render_goals=render,
            render_rollouts=render,
            reward_params=reward_params,
            **extra_kwargs,
            **variant.get('vae_wrapped_env_kwargs', {})
        )

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    if presample_goals:
        # This will fail for online-parallel as presampled_goals will not be
        # serialized. Also don't use this for online-vae.
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            # Temporary VAE env used only to generate the goal dataset.
            vae_env = wrap_with_vae(image_env)
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=vae_env,
                env_id=variant.get('env_id', None),
                **variant['goal_generation_kwargs']
            )
            del vae_env
        else:
            # presumably a 0-d numpy object array holding the goals dict,
            # hence ``.item()`` -- TODO confirm against the saved file format.
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path
            ).item()
        del image_env
        # NOTE(review): this wraps ``env`` (not the previous ``image_env``),
        # so an ``env`` that already is an ImageEnv gets wrapped again here --
        # confirm this is intended.
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
            **variant.get('image_env_kwargs', {})
        )
        vae_env = wrap_with_vae(image_env, presampled_goals=presampled_goals)
        print("Presampling all goals only")
    else:
        vae_env = wrap_with_vae(image_env)
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs']
            )
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")

    return vae_env