in level_replay/envs.py [0:0]
def make_lr_venv(num_envs, env_name, seeds, device, **kwargs):
    """Build a wrapped vectorized environment and its level sampler.

    Args:
        num_envs: Forwarded as ``num_envs`` to the underlying vector env.
        env_name: A name contained in ``PROCGEN_ENVS`` or one starting
            with ``'MiniGrid'``; selects which backend is constructed.
        seeds: Passed to ``LevelSampler`` (and to ``VecMinigrid`` for
            MiniGrid environments).
        device: Forwarded to the final PyTorch wrapper.
        **kwargs: Optional settings read here:
            ``level_sampler``, ``level_sampler_args``,
            ``no_ret_normalization`` (default ``False``), and for Procgen
            only ``num_levels`` (default 1), ``start_level`` (default 0),
            ``distribution_mode`` (default ``'easy'``) and
            ``paint_vel_info`` (default ``False``).

    Returns:
        A tuple ``(envs, level_sampler)``: the PyTorch-wrapped env and the
        sampler handed to it (``None`` if none was given or created).

    Raises:
        ValueError: If ``env_name`` matches neither supported family.
    """
    level_sampler = kwargs.get('level_sampler')
    level_sampler_args = kwargs.get('level_sampler_args')
    ret_normalization = not kwargs.get('no_ret_normalization', False)

    is_procgen = env_name in PROCGEN_ENVS
    if not is_procgen and not env_name.startswith('MiniGrid'):
        raise ValueError(f'Unsupported env {env_name}')

    # Backend-specific base env plus the matching PyTorch wrapper class.
    if is_procgen:
        base_env = ProcgenEnv(
            num_envs=num_envs,
            env_name=env_name,
            num_levels=kwargs.get('num_levels', 1),
            start_level=kwargs.get('start_level', 0),
            distribution_mode=kwargs.get('distribution_mode', 'easy'),
            paint_vel_info=kwargs.get('paint_vel_info', False),
        )
        base_env = VecExtractDictObs(base_env, "rgb")
        torch_wrapper = VecPyTorchProcgen
    else:
        base_env = VecMinigrid(
            num_envs=num_envs, env_name=env_name, seeds=seeds)
        torch_wrapper = VecPyTorchMinigrid

    # Monitoring and (optional) return normalization are common to both.
    monitored = VecMonitor(venv=base_env, filename=None, keep_buf=100)
    venv = VecNormalize(venv=monitored, ob=False, ret=ret_normalization)

    # Non-empty sampler args replace any sampler passed via kwargs. For
    # MiniGrid only, non-empty seeds without args fall back to a sampler
    # using the 'random' strategy (also replacing a passed-in sampler).
    sampler_kwargs = None
    if level_sampler_args:
        sampler_kwargs = level_sampler_args
    elif not is_procgen and seeds:
        sampler_kwargs = {'strategy': 'random'}

    if sampler_kwargs is not None:
        level_sampler = LevelSampler(
            seeds,
            venv.observation_space,
            venv.action_space,
            **sampler_kwargs)

    envs = torch_wrapper(venv, device, level_sampler=level_sampler)
    return envs, level_sampler