mae_envs/envs/base.py
import numpy as np
import logging
from mujoco_worldgen import Floor, WorldBuilder, WorldParams, Env
from mae_envs.wrappers.multi_agent import (SplitMultiAgentActions, SplitObservations,
                                           SelectKeysWrapper)
from mae_envs.wrappers.util import DiscretizeActionWrapper, DiscardMujocoExceptionEpisodes
from mae_envs.wrappers.line_of_sight import AgentAgentObsMask2D
from mae_envs.modules.agents import Agents
from mae_envs.modules.walls import RandomWalls
from mae_envs.modules.objects import Boxes, Ramps


class Base(Env):
    '''
    Multi-agent Base Environment.

    Args:
        horizon (int): number of steps the agents get to act
        n_substeps (int): number of internal MuJoCo steps per outer environment step;
            essentially this is action repeat
        n_agents (int): number of agents in the environment
        floor_size (float or (float, float)): size of the floor; if a tuple of 2 floats, the floor
            size is randomized uniformly between them on each episode
        grid_size (int): size of the grid used to place objects on the floor
        action_lims (float tuple): lower and upper limits of MuJoCo actions
        deterministic_mode (bool): if True, seeds are incremented rather than randomly sampled
    '''
    def __init__(self, horizon=250, n_substeps=5, n_agents=2,
                 floor_size=6., grid_size=30,
                 action_lims=(-1.0, 1.0), deterministic_mode=False,
                 **kwargs):
        super().__init__(get_sim=self._get_sim,
                         get_obs=self._get_obs,
                         action_space=tuple(action_lims),
                         horizon=horizon,
                         deterministic_mode=deterministic_mode)
        self.n_agents = n_agents
        self.metadata = {}
        self.metadata['n_actors'] = n_agents
        self.horizon = horizon
        self.n_substeps = n_substeps
        if not isinstance(floor_size, (tuple, list, np.ndarray)):
            # A scalar floor size is treated as a degenerate [low, high] interval.
            self.floor_size_dist = [floor_size, floor_size]
        else:
            self.floor_size_dist = floor_size
        self.grid_size = grid_size
        self.kwargs = kwargs
        # Occupancy grid used by modules to avoid overlapping object placements.
        self.placement_grid = np.zeros((grid_size, grid_size))
        self.modules = []

    def add_module(self, module):
        self.modules.append(module)

    def _get_obs(self, sim):
        '''
        Loops through modules, calls their observation_step functions, and
        adds the results to the observation dictionary.
        '''
        obs = {}
        for module in self.modules:
            obs.update(module.observation_step(self, self.sim))
        return obs

    def _get_sim(self, seed):
        '''
        Calls build_world_step and then modify_sim_step for each module. If
        any build_world_step fails, world construction is restarted.
        '''
        self.floor_size = np.random.uniform(self.floor_size_dist[0], self.floor_size_dist[1])
        self.metadata['floor_size'] = self.floor_size
        world_params = WorldParams(size=(self.floor_size, self.floor_size, 2.5),
                                   num_substeps=self.n_substeps)
        successful_placement = False
        failures = 0
        while not successful_placement:
            if (failures + 1) % 10 == 0:
                logging.warning(f"Failed {failures} times in creating environment")
            builder = WorldBuilder(world_params, seed)
            floor = Floor()
            builder.append(floor)
            # Reset the occupancy grid before each placement attempt.
            self.placement_grid = np.zeros((self.grid_size, self.grid_size))
            successful_placement = np.all([module.build_world_step(self, floor, self.floor_size)
                                           for module in self.modules])
            failures += 1
        sim = builder.get_sim()
        for module in self.modules:
            module.modify_sim_step(self, sim)
        return sim
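

# --- Illustrative only (not part of the original file) ---
# A minimal sketch of the duck-typed interface Base expects from a module:
# build_world_step() places content and reports whether placement succeeded,
# modify_sim_step() adjusts the compiled sim, and observation_step() returns a
# dict that is merged into the observations in Base._get_obs. The class name
# and its behavior here are hypothetical examples, not part of the repo.
class _ExampleMarkerModule:
    '''Example module that records the floor size and exposes it as an observation.'''

    def build_world_step(self, env, floor, floor_size):
        # Return True when placement succeeded; Base retries world construction otherwise.
        self._floor_size = floor_size
        return True

    def modify_sim_step(self, env, sim):
        # No changes to the compiled simulation are needed for this sketch.
        pass

    def observation_step(self, env, sim):
        # Keys returned here appear in the observation dict built by Base._get_obs.
        return {'example_floor_size': np.array([self._floor_size])}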


def make_env(n_substeps=5, horizon=250, deterministic_mode=False, n_agents=2,
             n_boxes=2, n_ramps=1):
    '''
    This make_env function is not used anywhere else; it exists to provide a simple,
    bare-bones example of how to construct a multi-agent environment using the
    modules framework.
    '''
    env = Base(n_agents=n_agents, n_substeps=n_substeps, horizon=horizon,
               deterministic_mode=deterministic_mode)

    # Add modules: walls, movable objects, and finally the agents themselves
    env.add_module(RandomWalls(grid_size=30, num_rooms=4, min_room_size=6, door_size=2))
    if n_boxes > 0:
        env.add_module(Boxes(n_boxes=n_boxes))
    if n_ramps > 0:
        env.add_module(Ramps(n_ramps=n_ramps))
    env.add_module(Agents(n_agents))
    env.reset()

    # Add wrappers: split actions/observations per agent, apply the agent-agent
    # line-of-sight mask, and discard episodes that raise MuJoCo exceptions
    keys_self = ['agent_qpos_qvel']
    keys_mask_self = ['mask_aa_obs']
    keys_external = ['agent_qpos_qvel']
    keys_mask_external = []
    env = SplitMultiAgentActions(env)
    env = DiscretizeActionWrapper(env, 'action_movement')
    env = AgentAgentObsMask2D(env)
    env = SplitObservations(env, keys_self + keys_mask_self)
    env = SelectKeysWrapper(env, keys_self=keys_self,
                            keys_other=keys_external + keys_mask_self + keys_mask_external)
    env = DiscardMujocoExceptionEpisodes(env)
    return env
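

# --- Illustrative only (not part of the original file) ---
# A minimal smoke-test sketch, assuming the wrapped environment follows the
# classic Gym reset()/step() API and that its action space supports sample().
if __name__ == '__main__':
    env = make_env(n_agents=2, n_boxes=2, n_ramps=1)
    obs = env.reset()
    for _ in range(10):
        # Sample a random (dict) action for all agents and step the simulation.
        action = env.action_space.sample()
        obs, rew, done, info = env.step(action)
        if done:
            obs = env.reset()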