in habitat_baselines/utils/gym_adapter.py [0:0]
def __init__(self, env, save_orig_obs: bool = False):
    """Wrap ``env`` behind a flat ``Box`` action space.

    Reads the ``GYM_*`` options from ``env._rl_config``, picks the single
    action selected by ``GYM_ACTION_KEYS`` (every action of ``env`` when
    the option is unset) and concatenates that action's ``Box`` arguments
    into one ``Box`` with bounds ``[-1, 1]``. ``self.action_mapping``
    records, per argument name, the ``(start, end)`` slice it occupies in
    the flat action vector.

    The observation space is produced by ``smash_observation_space`` from
    ``GYM_OBS_KEYS`` (every observation key when unset). When desired or
    achieved goal keys are configured, the observation space becomes a
    ``spaces.Dict`` with ``"observation"`` plus ``"desired_goal"`` and/or
    ``"achieved_goal"`` entries.

    Args:
        env: Environment to wrap. This method reads its ``_rl_config``
            (via ``.get``), ``action_space.spaces`` and
            ``observation_space``.
        save_orig_obs: Stored as ``self._save_orig_obs``; its use lies
            outside this method.

    Raises:
        ValueError: If the selected action set does not contain exactly
            one action, if that action's space is not a ``spaces.Dict``,
            if any of its arguments is not a ``spaces.Box``, or if it has
            no arguments at all.
    """
    rl_config = env._rl_config
    self._gym_goal_keys = rl_config.get("GYM_DESIRED_GOAL_KEYS", [])
    self._gym_achieved_goal_keys = rl_config.get(
        "GYM_ACHIEVED_GOAL_KEYS", []
    )
    self._fix_info_dict = rl_config.get("GYM_FIX_INFO_DICT", True)
    self._gym_action_keys = rl_config.get("GYM_ACTION_KEYS", None)
    self._gym_obs_keys = rl_config.get("GYM_OBS_KEYS", None)

    # An unset key list means "use everything the env exposes".
    if self._gym_obs_keys is None:
        self._gym_obs_keys = list(env.observation_space.spaces.keys())
    if self._gym_action_keys is None:
        self._gym_action_keys = list(env.action_space.spaces.keys())

    # _gym_action_keys is guaranteed non-None here, so a plain membership
    # test is enough to select the configured actions.
    selected_actions = spaces.Dict(
        {
            name: space
            for name, space in env.action_space.spaces.items()
            if name in self._gym_action_keys
        }
    )

    self._last_obs: Optional[Observations] = None
    self.action_mapping = {}
    self._save_orig_obs = save_orig_obs
    self.orig_obs = None

    # NOTE: this also fires when *no* action matched the configured keys.
    if len(selected_actions.spaces) != 1:
        raise ValueError(
            "Cannot convert this action space, more than one action"
        )

    self.orig_action_name = next(iter(selected_actions.spaces))
    action_args = selected_actions.spaces[self.orig_action_name]
    if not isinstance(action_args, spaces.Dict):
        raise ValueError("Cannot convert this action space")
    if not all(
        isinstance(sub_space, spaces.Box)
        for sub_space in action_args.spaces.values()
    ):
        raise ValueError("Cannot convert this action space")
    if len(action_args.spaces) == 0:
        # Previously this case fell through to an UnboundLocalError when
        # the flat Box was built; fail with the same exception type as
        # every other unconvertible action space instead.
        raise ValueError(
            "Cannot convert this action space, action has no arguments"
        )

    # Lay the arguments out back to back in the flat action vector. Only
    # the first dimension of each Box's shape is counted.
    offset = 0
    for name, sub_space in action_args.spaces.items():
        next_offset = offset + sub_space.shape[0]
        self.action_mapping[name] = (offset, next_offset)
        offset = next_offset

    self.action_space = spaces.Box(
        shape=(offset,), low=-1.0, high=1.0, dtype=np.float32
    )

    self.observation_space = smash_observation_space(
        env.observation_space, self._gym_obs_keys
    )

    dict_space = {
        "observation": self.observation_space,
    }
    if len(self._gym_goal_keys) > 0:
        dict_space["desired_goal"] = smash_observation_space(
            env.observation_space, self._gym_goal_keys
        )
    if len(self._gym_achieved_goal_keys) > 0:
        dict_space["achieved_goal"] = smash_observation_space(
            env.observation_space, self._gym_achieved_goal_keys
        )

    # Only switch to a Dict observation space when goal entries exist;
    # otherwise keep the plain smashed observation space.
    if len(dict_space) > 1:
        self.observation_space = spaces.Dict(dict_space)

    self._env = env