mae_envs/modules/construction_sites.py (96 lines of code) (raw):

import numpy as np
from mujoco_worldgen.util.types import store_args
from mae_envs.modules import EnvModule, rejection_placement


class ConstructionSites(EnvModule):
    '''
        Adds construction sites to the environment. A construction site consists of 5 regular
        mujoco sites, with four of them (the 'corner' sites) forming a rectangle and the last
        site being placed in the center of the rectangle.

        Args:
            n_sites (int or (int, int)): Number of construction sites. If tuple of ints, every
                episode the number of sites is drawn uniformly from
                range(n_sites[0], n_sites[1] + 1)
            placement_fn (fn or list of fns): See mae_envs.modules.util:rejection_placement
                for spec. If a list (or tuple) of functions, then it is assumed there is one
                function given per construction site, indexed by site number.
            site_name (str): Name for the sites. Center sites are named f'{site_name}{i}' and
                corner sites f'{site_name}{i}_corner{j}' for j in 0..3.
            site_size (float): Side length of a square site, in the same units as floor_size.
                Elongated sites are (3.3 x 0.3) or (0.3 x 3.3) times this size.
            site_height (float): Relative z coordinate used when marking all five sites.
            n_elongated_sites (int or (int, int)): Number of elongated sites. If tuple of ints,
                every episode the number of elongated sites is drawn uniformly from
                range(n_elongated_sites[0], n_elongated_sites[1] + 1)
    '''
    @store_args
    def __init__(self, n_sites, placement_fn=None, site_name='construction_site',
                 site_size=0.5, site_height=0.25, n_elongated_sites=0):
        # The other methods read self.placement_fn, self.site_size, etc. without assigning
        # them, so store_args exposes the arguments as attributes. Here we only normalize
        # scalar counts into inclusive [low, high] ranges.
        if not isinstance(n_sites, (tuple, list, np.ndarray)):
            self.n_sites = [n_sites, n_sites]
        if not isinstance(n_elongated_sites, (tuple, list, np.ndarray)):
            self.n_elongated_sites = [n_elongated_sites, n_elongated_sites]

    def _mark_site_square(self, floor, floor_size, site_name, site_relative_xyz, site_dims):
        '''
            Marks one construction site on the floor: a center site named `site_name` plus
            four corner sites named f'{site_name}_corner{i}' (i = 0..3) forming a rectangle.

            Args:
                floor: object exposing .mark(name, relative_xyz=..., size=..., rgba=...)
                floor_size: floor extent used to convert site_dims into relative coordinates
                site_name (str): name of the center site
                site_relative_xyz ((float, float, float)): relative position of the center
                site_dims (np.ndarray, length 2): (x, y) extent of the site in floor units
        '''
        x, y, z = site_relative_xyz
        floor.mark(site_name, relative_xyz=(x, y, z), rgba=[1., 1., 1., 1.], size=0.1)
        # Half-extent of the rectangle, expressed in relative floor coordinates.
        corner_rel_offset_x, corner_rel_offset_y = (site_dims / floor_size) / 2
        corner_rel_xy = [[x - corner_rel_offset_x, y - corner_rel_offset_y],
                         [x - corner_rel_offset_x, y + corner_rel_offset_y],
                         [x + corner_rel_offset_x, y - corner_rel_offset_y],
                         [x + corner_rel_offset_x, y + corner_rel_offset_y]]
        for i, (x_corner, y_corner) in enumerate(corner_rel_xy):
            floor.mark(f'{site_name}_corner{i}', relative_xyz=(x_corner, y_corner, z),
                       size=0.05, rgba=[0.8, 0.8, 0.8, 1.])

    def build_world_step(self, env, floor, floor_size):
        '''
            Samples this episode's number of sites and elongated sites, records them in
            env.metadata, and marks every site on the floor.

            Returns:
                bool: False if rejection_placement failed to place at least one site,
                    True otherwise.
        '''
        self.curr_n_sites = env._random_state.randint(self.n_sites[0], self.n_sites[1] + 1)
        self.curr_n_elongated_sites = env._random_state.randint(
            self.n_elongated_sites[0], self.n_elongated_sites[1] + 1)
        env.metadata['curr_n_sites'] = self.curr_n_sites
        env.metadata['curr_n_elongated_sites'] = self.curr_n_elongated_sites

        # Per-site (x, y) extents; square by default.
        self.site_size_array = self.site_size * np.ones((self.curr_n_sites, 2))
        if self.curr_n_elongated_sites > 0:
            # The first n_xaligned sites are long along x, the next ones (up to
            # curr_n_elongated_sites) are long along y. Slices silently truncate if
            # curr_n_elongated_sites exceeds curr_n_sites.
            n_xaligned = env._random_state.randint(self.curr_n_elongated_sites + 1)
            self.site_size_array[:n_xaligned, :] = self.site_size * np.array([3.3, 0.3])
            self.site_size_array[n_xaligned:self.curr_n_elongated_sites, :] = (
                self.site_size * np.array([0.3, 3.3]))

        successful_placement = True
        for i in range(self.curr_n_sites):
            if self.placement_fn is not None:
                # A sequence of placement functions provides one function per site.
                _placement_fn = (self.placement_fn[i]
                                 if isinstance(self.placement_fn, (list, tuple))
                                 else self.placement_fn)
                pos, _ = rejection_placement(env, _placement_fn, floor_size,
                                             self.site_size_array[i])
                if pos is not None:
                    self._mark_site_square(floor, floor_size, f'{self.site_name}{i}',
                                           (pos[0], pos[1], self.site_height),
                                           self.site_size_array[i])
                else:
                    # Keep marking the remaining sites; failure is reported via the
                    # return value.
                    successful_placement = False
            else:
                # place the site so that all the corners are still within the play area
                # NOTE(review): _mark_site_square offsets corners by dim / floor_size / 2,
                # but pos_min divides by floor_size * 1.1, so pos_min is slightly smaller
                # than the largest corner offset and corners can extend a little past
                # [0, 1]. Confirm whether a margin factor (multiplying by 1.1) was intended.
                pos_min = self.site_size_array[i].max() / (floor_size * 1.1) / 2
                pos = env._random_state.uniform(pos_min, 1 - pos_min, 2)
                self._mark_site_square(floor, floor_size, f'{self.site_name}{i}',
                                       (pos[0], pos[1], self.site_height),
                                       self.site_size_array[i])

        return successful_placement

    def modify_sim_step(self, env, sim):
        '''
            Caches the sim site ids of every center site and every corner site, in site
            order (corners grouped four per site).
        '''
        # dtype=int so the arrays remain valid indices when curr_n_sites == 0:
        # np.array([]) otherwise defaults to float64, which numpy refuses as an index.
        self.construction_site_idxs = np.array(
            [sim.model.site_name2id(f'{self.site_name}{i}') for i in range(self.curr_n_sites)],
            dtype=int)
        self.construction_site_corner_idxs = np.array(
            [sim.model.site_name2id(f'{self.site_name}{i}_corner{j}')
             for i in range(self.curr_n_sites) for j in range(4)],
            dtype=int)

    def observation_step(self, env, sim):
        '''
            Returns construction site observations:
                construction_site_pos: center positions, one row per site
                construction_site_corner_pos: corner positions, four rows per site
                construction_site_obs: per-site concatenation of center and flattened
                    corner positions
                mask_acs_obs: all-ones mask of shape (n_agents, curr_n_sites)
        '''
        site_pos = sim.data.site_xpos[self.construction_site_idxs]
        site_corner_pos = sim.data.site_xpos[self.construction_site_corner_idxs]
        # 4 corners x 3 coordinates flattened into 12 values per site.
        site_obs = np.concatenate((site_pos,
                                   site_corner_pos.reshape((self.curr_n_sites, 12))),
                                  axis=-1)
        mask_site_obs = np.ones((env.n_agents, self.curr_n_sites))
        obs = {'construction_site_pos': site_pos,
               'construction_site_corner_pos': site_corner_pos,
               'construction_site_obs': site_obs,
               'mask_acs_obs': mask_site_obs}

        return obs