in MTRF/algorithms/softlearning/environments/adapters/gym_adapter.py [0:0]
def __init__(self,
             domain,
             task,
             *args,
             env=None,
             normalize=True,
             observation_keys=(),
             goal_keys=(),
             unwrap_time_limit=True,
             pixel_wrapper_kwargs=None,
             is_metaworld=False,
             **kwargs):
    """Build (or adopt) a gym environment and expose dict-style spaces.

    Either `env` is given (then `domain`, `task` and `**kwargs` must be
    empty/None), or the environment is constructed here from `domain` and
    `task`.

    Args:
        domain: First part of the environment id. The id is looked up as
            `"{domain}-{task}"`, falling back to `"{domain}{task}"` when
            gym reports the first form as unregistered.
        task: Second part of the environment id (see `domain`).
        *args: Not supported; must be empty.
        env: Optional pre-built environment to wrap instead of creating
            one.
        normalize: If true, wrap the environment in
            `NormalizeActionWrapper`.
        observation_keys: Keys of a Dict observation space to expose.
            When empty, all keys of the Dict space are used. When the
            environment is constructed here, a non-empty value is also
            forwarded to the constructor as the `observation_keys` kwarg.
            For Box observation spaces the single key
            `DEFAULT_OBSERVATION_KEY` is used instead.
        goal_keys: Forwarded to the base-class initializer.
            `self.goal_keys` is read below to keep those entries in the
            observation space; presumably it is set by the base class --
            TODO confirm.
        unwrap_time_limit: If true, strip an outermost `TimeLimit`
            wrapper.
        pixel_wrapper_kwargs: When not None, the environment is wrapped
            in `PixelObservationWrapper(env, **pixel_wrapper_kwargs)`.
        is_metaworld: If true, construct the environment through
            `metaworld.ML1` instead of `gym.envs.make`.
        **kwargs: Forwarded to the base-class initializer and to
            `gym.envs.make`; stored as `self._env_kwargs` when the
            environment is constructed here.

    Raises:
        AssertionError: On positional args, on an inconsistent
            `env`/`domain`/`task`/`kwargs` combination, or on an unknown
            metaworld environment id.
        NotImplementedError: If the observation space is neither Dict nor
            Box, or if the action space is not flat.
    """
    assert not args, (
        "Gym environments don't support args. Use kwargs instead.")

    self.normalize = normalize
    self.unwrap_time_limit = unwrap_time_limit
    self.is_metaworld = is_metaworld

    super(GymAdapter, self).__init__(
        domain, task, *args, goal_keys=goal_keys, **kwargs)

    if env is None:
        assert (domain is not None and task is not None), (domain, task)

        # Only inject observation_keys into the constructor kwargs when we
        # actually construct the environment. Doing it unconditionally
        # made `env=...` together with `observation_keys=...` always trip
        # the `assert not kwargs` check in the other branch.
        if observation_keys:
            kwargs.update({"observation_keys": observation_keys})

        try:
            env_id = f"{domain}-{task}"
            if self.is_metaworld:
                import metaworld
                assert env_id in metaworld.ML1.ENV_NAMES, (
                    "{} not in list of available metaworld environments".format(env_id))
                ml1 = self.ml1 = metaworld.ML1(env_id)
                env = ml1.train_classes[env_id]()
                # Always pick the first task for now. A separate local
                # name avoids rebinding the `task` parameter.
                metaworld_task = self.metaworld_task = ml1.train_tasks[0]
                env.set_task(metaworld_task)
            else:
                env = gym.envs.make(env_id, **kwargs)
        except gym.error.UnregisteredEnv:
            # Retry with the un-hyphenated id.
            env_id = f"{domain}{task}"
            env = gym.envs.make(env_id, **kwargs)
        self._env_kwargs = kwargs
    else:
        assert not kwargs
        assert domain is None and task is None, (domain, task)

    if isinstance(env, wrappers.TimeLimit) and unwrap_time_limit:
        # Remove the TimeLimit wrapper that sets 'done = True' when
        # the time limit specified for each environment has been passed and
        # therefore the environment is not Markovian (terminal condition
        # depends on time rather than state).
        env = env.env

    if normalize:
        env = NormalizeActionWrapper(env)

    if pixel_wrapper_kwargs is not None:
        env = PixelObservationWrapper(env, **pixel_wrapper_kwargs)

    self._env = env

    env_observation_space = self._env.observation_space
    if isinstance(env_observation_space, spaces.Dict):
        dict_observation_space = env_observation_space
        # Default to every key of the environment's Dict space. This reads
        # the env's space directly: `self._observation_space` is only
        # assigned further down, so it cannot be consulted yet.
        self.observation_keys = (
            observation_keys or (*dict_observation_space.spaces.keys(), ))
    elif isinstance(env_observation_space, spaces.Box):
        # Present a flat Box space as a one-entry Dict space.
        dict_observation_space = spaces.Dict(OrderedDict((
            (DEFAULT_OBSERVATION_KEY, env_observation_space),
        )))
        self.observation_keys = (DEFAULT_OBSERVATION_KEY, )
    else:
        # Fail with a clear error instead of an unbound-local error below.
        raise NotImplementedError(
            "Unsupported observation space type: {}".format(
                type(env_observation_space)))

    # Keep only the requested observation and goal entries; spaces are
    # deep-copied so the adapter's space is independent of the env's.
    exposed_keys = self.observation_keys + self.goal_keys
    self._observation_space = type(dict_observation_space)([
        (name, copy.deepcopy(space))
        for name, space in dict_observation_space.spaces.items()
        if name in exposed_keys
    ])

    if len(self._env.action_space.shape) > 1:
        raise NotImplementedError(
            "Shape of the action space ({}) is not flat, make sure to"
            " check the implemenation.".format(self._env.action_space))

    self._action_space = self._env.action_space