mae_envs/modules/agents.py (96 lines of code) (raw):

import numpy as np
from mujoco_worldgen.util.types import store_args
from mujoco_worldgen.util.sim_funcs import (qpos_idxs_from_joint_prefix,
                                            qvel_idxs_from_joint_prefix)
from mujoco_worldgen.transforms import set_geom_attr_transform
from mujoco_worldgen.util.rotation import normalize_angles
from mae_envs.util.transforms import (add_weld_equality_constraint_transform,
                                      set_joint_damping_transform)
from mae_envs.modules import EnvModule, rejection_placement, get_size_from_xml
from mujoco_worldgen import ObjFromXML


class Agents(EnvModule):
    '''
        Add Agents to the environment.
        Args:
            n_agents (int): number of agents
            placement_fn (fn or list of fns): See mae_envs.modules.util:rejection_placement for
                spec. If list of functions, then it is assumed there is one function given
                per agent
            color (tuple or list of tuples): rgba for agent. If list of tuples, then it is
                assumed there is one color given per agent
            friction (float): agent friction
            damp_z (bool): if False, reduce z damping to 1
            polar_obs (bool): Give observations about rotation in polar coordinates
    '''
    @store_args
    def __init__(self, n_agents, placement_fn=None, color=None, friction=None,
                 damp_z=False, polar_obs=True):
        # Intentionally empty: the @store_args decorator is relied upon to copy the
        # constructor arguments onto self (the methods below read self.n_agents etc.).
        pass

    def build_world_step(self, env, floor, floor_size):
        '''
            Create one "particle_hinge" object per agent and append it to the floor.
            Args:
                env: environment; its ``metadata`` dict is updated in place
                floor: world object the agents are appended to
                floor_size: forwarded unchanged to rejection_placement
            Returns:
                bool: False if rejection_placement returned no position for at least one
                    agent, True otherwise. A failure does not stop the loop; the remaining
                    agents are still processed.
        '''
        env.metadata['n_agents'] = self.n_agents
        successful_placement = True

        # Drop spawn positions left over from a previous build so that placement
        # functions never see stale entries.
        for i in range(self.n_agents):
            env.metadata.pop(f"agent{i}_initpos", None)

        for i in range(self.n_agents):
            obj = ObjFromXML("particle_hinge", name=f"agent{i}")
            if self.friction is not None:
                obj.add_transform(set_geom_attr_transform('friction', self.friction))
            if self.color is not None:
                # A sequence of sequences means one colour per agent; otherwise the
                # single rgba value is shared by all agents.
                _color = (self.color[i]
                          if isinstance(self.color[0], (list, tuple, np.ndarray))
                          else self.color)
                obj.add_transform(set_geom_attr_transform('rgba', _color))
            if not self.damp_z:
                obj.add_transform(set_joint_damping_transform(1, 'tz'))

            if self.placement_fn is not None:
                # A list means one placement function per agent.
                _placement_fn = (self.placement_fn[i]
                                 if isinstance(self.placement_fn, list)
                                 else self.placement_fn)
                obj_size = get_size_from_xml(obj)
                pos, pos_grid = rejection_placement(env, _placement_fn, floor_size, obj_size)
                if pos is not None:
                    floor.append(obj, placement_xy=pos)
                    # store spawn position in metadata. This allows sampling subsequent agents
                    # close to previous agents
                    env.metadata[f"agent{i}_initpos"] = pos_grid
                else:
                    successful_placement = False
            else:
                floor.append(obj)
        return successful_placement

    def modify_sim_step(self, env, sim):
        '''
            Cache per-agent qpos/qvel indices on self and publish the agent geom ids
            in env.metadata['agent_geom_idxs'].
        '''
        # Cache qpos, qvel idxs
        self.agent_qpos_idxs = np.array([qpos_idxs_from_joint_prefix(sim, f'agent{i}')
                                         for i in range(self.n_agents)])
        self.agent_qvel_idxs = np.array([qvel_idxs_from_joint_prefix(sim, f'agent{i}')
                                         for i in range(self.n_agents)])
        env.metadata['agent_geom_idxs'] = [sim.model.geom_name2id(f'agent{i}:agent')
                                           for i in range(self.n_agents)]

    def observation_step(self, env, sim):
        '''
            Build the agent observations from the current simulator state.
            Requires modify_sim_step to have run first (it sets the cached indices).
            Returns:
                dict with keys
                    'agent_qpos_qvel': raw per-agent qpos and qvel concatenated on the
                        last axis (the angle is NOT converted to polar form here)
                    'agent_angle': last qpos column minus pi/2, passed through
                        normalize_angles
                    'agent_pos': first three columns of the per-agent qpos
        '''
        qpos = sim.data.qpos.copy()
        qvel = sim.data.qvel.copy()

        # One row per agent. NOTE(review): this assumes every agent exposes the same
        # number of joints so the cached index arrays are rectangular -- confirm.
        agent_qpos = qpos[self.agent_qpos_idxs]
        agent_qvel = qvel[self.agent_qvel_idxs]
        agent_angle = agent_qpos[:, [-1]] - np.pi / 2  # Rotate the angle to match visual front
        # Built before the optional polar rewrite below, so it always holds raw qpos.
        agent_qpos_qvel = np.concatenate([agent_qpos, agent_qvel], -1)
        if self.polar_obs:
            # Replace the trailing angle column by its (cos, sin) pair.
            polar_angle = np.concatenate([np.cos(agent_angle), np.sin(agent_angle)], -1)
            agent_qpos = np.concatenate([agent_qpos[:, :-1], polar_angle], -1)
        agent_angle = normalize_angles(agent_angle)
        obs = {
            'agent_qpos_qvel': agent_qpos_qvel,
            'agent_angle': agent_angle,
            'agent_pos': agent_qpos[:, :3]}

        return obs


class AgentManipulation(EnvModule):
    '''
        Adding this module is necessary for the grabbing mechanic implemented in GrabObjWrapper
        (found in mae_envs/wrappers/manipulation.py) to work correctly.
    '''
    @store_args
    def __init__(self):
        pass

    def build_world_step(self, env, floor, floor_size):
        '''
            Add one weld equality constraint transform per agent to the floor.
            Always returns True.
        '''
        for i in range(env.n_agents):
            floor.add_transform(add_weld_equality_constraint_transform(
                f'agent{i}:gripper', f'agent{i}:particle', 'floor0'))
        return True

    def modify_sim_step(self, env, sim):
        '''
            Zero every entry of sim.model.eq_active.
        '''
        # Presumably this leaves the welds added in build_world_step inactive until
        # the grab wrapper enables them -- verify against GrabObjWrapper.
        sim.model.eq_active[:] = 0