mae_envs/wrappers/util.py (170 lines of code) (raw):

import gym from mujoco_py import MujocoException from gym.spaces import Dict, Box import numpy as np from copy import deepcopy import logging def update_obs_space(env, delta): spaces = env.observation_space.spaces.copy() for key, shape in delta.items(): spaces[key] = Box(-np.inf, np.inf, shape, np.float32) return Dict(spaces) class NumpyArrayRewardWrapper(gym.RewardWrapper): """ Convenience wrapper that casts rewards to the multiagent format (numpy array of shape (n_agents,)) """ def __init__(self, env): super().__init__(env) def reward(self, rew): return np.zeros((self.unwrapped.n_agents,)) + rew class DiscretizeActionWrapper(gym.ActionWrapper): ''' Take a Box action and convert it to a MultiDiscrete Action through quantization Args: action_key: (string) action to discretize nbuckets: (int) number of discrete actions per dimension. It should be odd such that actions centered around 0 will have the middle action be 0. ''' def __init__(self, env, action_key, nbuckets=11): super().__init__(env) self.action_key = action_key self.discrete_to_continuous_act_map = [] for i, ac_space in enumerate(self.action_space.spaces[action_key].spaces): assert isinstance(ac_space, Box) action_map = np.array([np.linspace(low, high, nbuckets) for low, high in zip(ac_space.low, ac_space.high)]) _nbuckets = np.ones((len(action_map))) * nbuckets self.action_space.spaces[action_key].spaces[i] = gym.spaces.MultiDiscrete(_nbuckets) self.discrete_to_continuous_act_map.append(action_map) self.discrete_to_continuous_act_map = np.array(self.discrete_to_continuous_act_map) def action(self, action): action = deepcopy(action) ac = action[self.action_key] # helper variables for indexing the discrete-to-continuous action map agent_idxs = np.tile(np.arange(ac.shape[0])[:, None], ac.shape[1]) ac_idxs = np.tile(np.arange(ac.shape[1]), ac.shape[0]).reshape(ac.shape) action[self.action_key] = self.discrete_to_continuous_act_map[agent_idxs, ac_idxs, ac] return action class DiscardMujocoExceptionEpisodes(gym.Wrapper): ''' Catches Mujoco Exceptions. Sends signal to discard Episode. ''' def __init__(self, env): super().__init__(env) self.episode_error = False def step(self, action): assert not self.episode_error, "Won't Continue Episode After Mujoco Exception -- \ Please discard episode and reset. If info['discard_episode'] is True the episode\ should be discarded" try: obs, rew, done, info = self.env.step(action) info['discard_episode'] = False except MujocoException as e: self.episode_error = True # Done is set to False such that rollout workers do not accidently send data in # the event that timelimit is up in the same step as an error occured. obs, rew, done, info = {}, 0.0, False, {'discard_episode': True} logging.info(str(e)) logging.info("Encountered Mujoco Exception During Environment Step.\ Reset Episode Required") return obs, rew, done, info def reset(self): try: obs = self.env.reset() except MujocoException: logging.info("Encountered Mujoco Exception During Environment Reset.\ Trying Reset Again") obs = self.reset() self.episode_error = False return obs class MaskActionWrapper(gym.Wrapper): ''' For a boolean action, sets it to zero given a mask from the previous step. For example you could mask the grab action based on whether you can see the box Args: action_key (string): key in action dictionary to be masked mask_keys (string): keys in observation dictionary with which to mask. The shape of the concatenation of the masks (along the 1st dimension) should exactly match that of action_key ''' def __init__(self, env, action_key, mask_keys): super().__init__(env) self.action_key = action_key self.mask_keys = mask_keys def reset(self): self.prev_obs = self.env.reset() return deepcopy(self.prev_obs) def step(self, action): mask = np.concatenate([self.prev_obs[k] for k in self.mask_keys], -1) action[self.action_key] = np.logical_and(action[self.action_key], mask) self.prev_obs, rew, done, info = self.env.step(action) return deepcopy(self.prev_obs), rew, done, info class AddConstantObservationsWrapper(gym.ObservationWrapper): ''' Adds new constant observations to the environment. Args: new_obs: Dictionary with the new observations. ''' def __init__(self, env, new_obs): super().__init__(env) self.new_obs = new_obs for obs_key in self.new_obs: assert obs_key not in self.observation_space.spaces, ( f'Observation key {obs_key} exists in original observation space') if type(self.new_obs[obs_key]) in [list, tuple]: self.new_obs[obs_key] = np.array(self.new_obs[obs_key]) shape = self.new_obs[obs_key].shape self.observation_space = update_obs_space(self, {obs_key: shape}) def observation(self, obs): for key, val in self.new_obs.items(): obs[key] = val return obs class SpoofEntityWrapper(gym.ObservationWrapper): ''' Add extra entities along entity dimension such that shapes can match between environments with differing number of entities. This is meant to be used after SplitObservations and SelectKeysWrapper. This will also add masks that are 1 except along the new columns (which could be used by fully observed value function) Args: total_n_entities (int): total number of entities after spoofing (including spoofed ones) keys (list): observation keys with which to add entities along the second dimension mask_keys (list): mask keys with which to add columns. ''' def __init__(self, env, total_n_entities, keys, mask_keys): super().__init__(env) self.total_n_entities = total_n_entities self.keys = keys self.mask_keys = mask_keys for key in self.keys + self.mask_keys: shape = list(self.observation_space.spaces[key].shape) shape[1] = total_n_entities self.observation_space = update_obs_space(self, {key: shape}) for key in self.mask_keys: shape = list(self.observation_space.spaces[key].shape) self.observation_space = update_obs_space(self, {key + '_spoof': shape}) def observation(self, obs): for key in self.keys: n_to_spoof = self.total_n_entities - obs[key].shape[1] if n_to_spoof > 0: obs[key] = np.concatenate([obs[key], np.zeros((obs[key].shape[0], n_to_spoof, obs[key].shape[-1]))], 1) for key in self.mask_keys: n_to_spoof = self.total_n_entities - obs[key].shape[1] obs[key + '_spoof'] = np.concatenate([np.ones_like(obs[key]), np.zeros((obs[key].shape[0], n_to_spoof))], -1) if n_to_spoof > 0: obs[key] = np.concatenate([obs[key], np.zeros((obs[key].shape[0], n_to_spoof))], -1) return obs class ConcatenateObsWrapper(gym.ObservationWrapper): ''' Group multiple observations under the same key in the observation dictionary. Args: obs_groups: dict of {key_to_save: [keys to concat]} ''' def __init__(self, env, obs_groups): super().__init__(env) self.obs_groups = obs_groups for key_to_save, keys_to_concat in obs_groups.items(): assert np.all([np.array(self.observation_space.spaces[keys_to_concat[0]].shape[:-1]) == np.array(self.observation_space.spaces[k].shape[:-1]) for k in keys_to_concat]), \ f"Spaces were {[(k, v) for k, v in self.observation_space.spaces.items() if k in keys_to_concat]}" new_last_dim = sum([self.observation_space.spaces[k].shape[-1] for k in keys_to_concat]) new_shape = list(self.observation_space.spaces[keys_to_concat[0]].shape[:-1]) + [new_last_dim] self.observation_space = update_obs_space(self, {key_to_save: new_shape}) def observation(self, obs): for key_to_save, keys_to_concat in self.obs_groups.items(): obs[key_to_save] = np.concatenate([obs[k] for k in keys_to_concat], -1) return obs