in svg/utils.py [0:0]
def _register_pets_envs():
    """Run ``svg.env.register_pets_environments`` (imported lazily)."""
    from svg.env import register_pets_environments
    register_pets_environments()


def _register_mbpo_envs():
    """Run ``svg.env.register_mbpo_environments`` (imported lazily)."""
    from svg.env import register_mbpo_environments
    register_mbpo_environments()


# cfg.env_name -> (id passed to gym.make, hook to call before gym.make or None).
# NOTE: 'mbpo_cheetah' builds the stock 'HalfCheetah-v2' id but still runs the
# MBPO registration hook; that side effect is kept on purpose.
_GYM_ENV_SPECS = {
    'Humanoid-v2': ('Humanoid-v2', None),
    'pets_cheetah': ('PetsCheetah-v0', _register_pets_envs),
    'pets_reacher': ('PetsReacher-v0', _register_pets_envs),
    'pets_pusher': ('PetsPusher-v0', _register_pets_envs),
    'mbpo_hopper': ('Hopper-v2', None),
    'mbpo_walker2d': ('Walker2d-v2', None),
    'mbpo_ant': ('AntTruncatedObs-v2', _register_mbpo_envs),
    'mbpo_cheetah': ('HalfCheetah-v2', _register_mbpo_envs),
    'mbpo_humanoid': ('HumanoidTruncatedObs-v2', _register_mbpo_envs),
}


def _assert_no_episode_limit_override(cfg):
    """Reject configs that set ``max_episode_steps`` for gym-style envs.

    These environments always take the episode length from the underlying
    environment, so a user-provided value would be silently ignored.
    Kept as an ``assert`` so the raised type stays ``AssertionError``.
    """
    assert not cfg.max_episode_steps, (
        'max_episode_steps must be unset for env %r; the episode length is '
        'taken from the environment itself' % (cfg.env_name,)
    )


def _make_mbbl_env(cfg):
    """Build an mbbl environment and return ``(env, set_seed)``."""
    from mbbl.env.env_register import make_env

    is_pets = 'gym_pets' in cfg.env_name
    # These environments keep the episode length in ``_env_info`` and are
    # seeded one wrapper level higher than the remaining mbbl environments.
    uses_env_info = is_pets or cfg.env_name == 'gym_fswimmer'

    misc_info = {'reset_type': 'gym'}
    if is_pets:
        misc_info['pets'] = True
    env, _meta = make_env(
        cfg.env_name, rand_seed=cfg.seed, misc_info=misc_info)

    # Copy the attributes of the inner ``_env`` onto the mbbl env before
    # wrapping it (presumably what RescaleAction expects to find -- confirm).
    inner = env._env
    env.metadata = inner.metadata
    env.reward_range = inner.reward_range
    env.spec = inner.spec
    env.unwrapped = inner.unwrapped
    env.close = inner.close

    env = RescaleAction(env, -1., 1.)
    if uses_env_info:
        env._max_episode_steps = env.env._env_info['max_length']
    else:
        env._max_episode_steps = env.env._env._max_episode_steps

    def render(mode, height, width, camera_id):
        # All arguments are accepted but ignored; always an rgb_array frame.
        return env.env._env.render(mode='rgb_array')

    env.render = render

    def set_seed(seed):
        if uses_env_info:
            return env.env._env.seed(seed)
        return env.env._env.env.seed(seed)

    return env, set_seed


def _make_rescaled_gym_env(gym_id, register=None):
    """Build ``gym.make(gym_id)`` wrapped in RescaleAction; return ``(env, set_seed)``."""
    if register is not None:
        register()
    env = RescaleAction(gym.make(gym_id), -1., 1.)
    env._max_episode_steps = env.env._max_episode_steps

    def render(mode, height, width, camera_id):
        # All arguments are accepted but ignored; always an rgb_array frame.
        return env.env.render(mode='rgb_array')

    env.render = render

    def set_seed(seed):
        return env.env.seed(seed)

    return env, set_seed


def _make_dmc_env(cfg):
    """Build a ``dmc`` environment (frame-stacked if ``cfg.pixels``); return ``(env, set_seed)``."""
    env = dmc.make(cfg)
    if cfg.pixels:
        env = FrameStack(env, k=cfg.frame_stack)

        def set_seed(seed):
            # One extra ``.env`` hop to get through the FrameStack wrapper.
            return env.env.env._env.task.random.seed(seed)
    else:
        def set_seed(seed):
            return env.env._env.task.random.seed(seed)

    return env, set_seed


def make_norm_env(cfg):
    """Create the environment selected by ``cfg.env_name``.

    Dispatch, in order:

    * names containing ``'gym'`` are built through
      ``mbbl.env.env_register.make_env``;
    * names listed in ``_GYM_ENV_SPECS`` are built with ``gym.make``;
    * anything else must start with ``'dmc_'`` and is built with
      ``dmc.make(cfg)``, plus ``FrameStack`` when ``cfg.pixels`` is set.

    The first two kinds are wrapped in ``RescaleAction(env, -1., 1.)``, get
    ``_max_episode_steps`` copied from the underlying environment and get a
    ``render(mode, height, width, camera_id)`` attribute that ignores its
    arguments and renders an ``rgb_array`` frame.  Every returned environment
    has a ``set_seed(seed)`` attribute that seeds the underlying environment.

    Args:
        cfg: configuration object.  Reads ``env_name`` always; ``seed`` and
            ``max_episode_steps`` for gym-style names; ``pixels`` and
            ``frame_stack`` for ``dmc_`` names.

    Returns:
        The constructed (and possibly wrapped) environment.

    Raises:
        AssertionError: if ``cfg.max_episode_steps`` is set for a gym-style
            environment, or if ``cfg.env_name`` matches no known kind.  The
            first check runs before anything is constructed, so an invalid
            config does not leave an unclosed environment behind.
    """
    env_name = cfg.env_name
    if 'gym' in env_name:
        _assert_no_episode_limit_override(cfg)
        env, set_seed = _make_mbbl_env(cfg)
    elif env_name in _GYM_ENV_SPECS:
        _assert_no_episode_limit_override(cfg)
        gym_id, register = _GYM_ENV_SPECS[env_name]
        env, set_seed = _make_rescaled_gym_env(gym_id, register)
    else:
        assert env_name.startswith('dmc_')
        env, set_seed = _make_dmc_env(cfg)

    env.set_seed = set_seed
    return env