def __init__()

in mae_envs/wrappers/multi_agent.py [0:0]


    def __init__(self, env, keys_self, keys_copy=None, keys_self_matrices=None):
        """Rebuild the observation space, splitting "self" observations out.

        Every entry of the wrapped env's Dict observation space is handled by
        the first rule below that matches its key:

        1. Key contains ``'mask'`` and is not in ``keys_self_matrices``:
           the space is kept as is.
        2. Key is in ``keys_self_matrices``: the bounds are sliced with
           ``[:, 1:]``, i.e. the first entry along axis 1 is dropped.
        3. Key is in ``keys_self``: the bounds are tiled to shape
           ``(n_agents, n_agents - 1, shape[1])`` (other agents' observations
           only). The untouched original space is additionally exposed under
           ``key + '_self'``.
        4. Key is in ``keys_copy``: the space is deep-copied.
        5. Otherwise: the bounds are tiled to shape
           ``(n_agents, shape[0], shape[1])``.

        Args:
            env: Environment to wrap. ``self.metadata['n_agents']`` and
                ``self.observation_space.spaces`` must be available after the
                parent constructor has run.
            keys_self: Iterable of observation keys treated as "self"
                observations (rule 3). Stored sorted in ``self.keys_self``.
            keys_copy: Optional iterable of keys whose space is copied
                unchanged (rule 4). ``None`` (the default) means no keys.
            keys_self_matrices: Optional iterable of keys handled by rule 2.
                ``None`` (the default) means no keys.

        Raises:
            AssertionError: If any observation space has fewer than two
                dimensions, or if a ``keys_self`` space's first dimension
                differs from the number of agents.
            KeyError: If a key in ``keys_self`` is missing from the wrapped
                observation space.
        """
        super().__init__(env)
        self.keys_self = sorted(keys_self)
        # None sentinels replace the former mutable ``[]`` defaults; both
        # spell "no keys" and still end up as fresh sorted lists.
        self.keys_copy = [] if keys_copy is None else sorted(keys_copy)
        self.keys_self_matrices = (
            [] if keys_self_matrices is None else sorted(keys_self_matrices)
        )
        self.n_agents = self.metadata['n_agents']

        def tiled_bounds(space, copies):
            # Repeat the 2-D bounds `copies` times along the last axis, then
            # fold the repeats into a new middle axis:
            # (shape[0], shape[1]) -> (shape[0], copies, shape[1]).
            shape = (space.shape[0], copies, space.shape[1])
            lows = np.tile(space.low, copies).reshape(shape)
            highs = np.tile(space.high, copies).reshape(shape)
            return lows, highs

        new_spaces = {}
        for k, v in self.observation_space.spaces.items():
            assert len(v.shape) > 1, f'Obs {k} has shape {v.shape}'
            if 'mask' in k and k not in self.keys_self_matrices:
                new_spaces[k] = v
            elif k in self.keys_self_matrices:
                new_spaces[k] = Box(low=v.low[:, 1:], high=v.high[:, 1:], dtype=v.dtype)
            elif k in self.keys_self:
                # If obs is a self obs, then we only want to include other
                # agents' obs here, as the self obs is passed separately
                # under the '<key>_self' entry added below.
                assert v.shape[0] == self.n_agents, \
                    f"For self obs, obs dim 0 should equal number of agents. {k} has shape {v.shape}"
                lows, highs = tiled_bounds(v, self.n_agents - 1)
                new_spaces[k] = Box(low=lows, high=highs, dtype=v.dtype)
            elif k in self.keys_copy:
                new_spaces[k] = deepcopy(v)
            else:
                # Tile once per agent, then move the agent axis to the front:
                # (shape[0], n_agents, shape[1]) -> (n_agents, shape[0], shape[1]).
                lows, highs = tiled_bounds(v, self.n_agents)
                new_spaces[k] = Box(
                    low=lows.transpose((1, 0, 2)),
                    high=highs.transpose((1, 0, 2)),
                    dtype=v.dtype,
                )

        # Expose each self observation's original (untiled) space separately.
        for k in self.keys_self:
            new_spaces[k + '_self'] = self.observation_space.spaces[k]

        self.observation_space = Dict(new_spaces)