gym-compete/gym_compete/new_envs/agents/agent.py (218 lines of code) (raw):

import xml.etree.ElementTree as ET

import numpy as np
import six
from gym.spaces import Box

from ..utils import list_filter


class Agent(object):
    '''
    Super class for all agents in a multi-agent mujoco environment.

    Each subclass should implement a move_reward method which gives the
    moving rewards for that agent.
    Each agent can also implement its own action space.
    Over-ride set_observation_space to change observation_space
    (default is a Box based on the _get_obs() implementation).
    After creation, an Env reference should be given by calling set_env.
    '''

    # Number of qpos entries contributed by a joint, keyed by the joint type
    # code read from model.jnt_type. Per MuJoCo's mjtJoint enum the codes are
    # 0: free (7), 1: ball (4), 2: slide (1), 3: hinge (1).
    JNT_NPOS = {
        0: 7,
        1: 4,
        2: 1,
        3: 1,
    }

    def __init__(self, agent_id, xml_path, nagents=2):
        '''
        Parse the agent XML and record its identity.

        agent_id: index of this agent; also used to build the name prefix
            ('agent<id>') that scopes its bodies/joints/geoms in the model.
        xml_path: path of the agent's XML description, parsed immediately.
        nagents: total number of agents in the environment.
        '''
        self.id = agent_id
        self.scope = 'agent' + str(self.id)
        self._xml_path = xml_path
        print("Reading agent XML from:", xml_path)
        self.xml = ET.parse(xml_path)
        # The environment is attached later through set_env().
        self.env = None
        self._env_init = False
        self.n_agents = nagents

    def set_env(self, env):
        '''
        Attach the environment and derive everything that depends on it:
        body/joint index caches, then observation and action spaces.
        '''
        self.env = env
        self._env_init = True
        self._set_body()
        self._set_joint()
        if self.n_agents > 1:
            self._set_other_joint()
        # Spaces come last: _get_obs() implementations may rely on the
        # index caches populated above.
        self.set_observation_space()
        self.set_action_space()

    def set_observation_space(self):
        '''
        Build an unbounded Box observation space sized from one observation.
        '''
        # _get_obs must be *called*: the bound method itself has no `.size`.
        obs = self._get_obs()
        self.obs_dim = obs.size
        high = np.inf * np.ones(self.obs_dim)
        low = -high
        self.observation_space = Box(low, high)

    def set_action_space(self):
        '''
        Build the Box action space from the actuators listed in the XML.

        Bounds are resolved in increasing priority:
        unbounded (+/- inf) -> the 'ctrlrange' of <default><motor> ->
        the 'ctrlrange' set on an individual actuator.
        '''
        actuators = list(self.xml.find('actuator'))
        self.action_dim = len(actuators)

        # Start unbounded; overwritten below if a default range exists.
        high = np.inf * np.ones(self.action_dim)
        low = -high

        default = self.xml.find('default')
        motor = default.find('motor') if default is not None else None
        ctrl = motor.get('ctrlrange') if motor is not None else None
        if ctrl:
            clow, chigh = map(float, ctrl.split())
            high = chigh * np.ones(self.action_dim)
            low = clow * np.ones(self.action_dim)

        # Per-actuator ranges override whatever was chosen above.
        for i, actuator in enumerate(actuators):
            ctrl = actuator.get('ctrlrange')
            if ctrl:
                clow, chigh = map(float, ctrl.split())
                high[i] = chigh
                low[i] = clow

        self._low = low
        self._high = high
        self.action_space = Box(low, high)

    # @property
    # def observation_space(self):
    #     return self.observation_space
    #
    # @property
    # def action_space(self):
    #     return self.action_space

    def in_scope(self, name):
        '''Return True if `name` starts with this agent's scope prefix.'''
        return name.startswith(six.b(self.scope))

    def in_agent_scope(self, name, agent_id):
        '''Return True if `name` starts with the prefix of agent `agent_id`.'''
        return name.startswith(six.b('agent' + str(agent_id)))

    def _set_body(self):
        '''
        Cache the names, ids and dof bookkeeping of this agent's bodies and
        derive the [qvel_start_idx, qvel_end_idx) range.
        '''
        # Read the name list once; it is reused for filtering and lookups.
        model_body_names = self.env.model.body_names
        self.body_names = list_filter(self.in_scope, model_body_names)
        self.body_ids = [model_body_names.index(body)
                         for body in self.body_names]
        self.body_dofnum = self.env.model.body_dofnum[self.body_ids]
        self.nv = self.body_dofnum.sum()
        self.body_dofadr = self.env.model.body_dofadr[self.body_ids]
        # Keep only non-negative dof addresses.
        dof = list_filter(lambda x: x >= 0, self.body_dofadr)
        self.qvel_start_idx = int(dof[0])
        # Walk back to the last body that actually owns dofs so the end
        # index covers its full dof count.
        last_dof_body_id = self.body_dofnum.shape[0] - 1
        while self.body_dofnum[last_dof_body_id] == 0:
            last_dof_body_id -= 1
        self.qvel_end_idx = int(dof[-1] + self.body_dofnum[last_dof_body_id])

    def _set_joint(self):
        '''
        Cache the names, ids and qpos bookkeeping of this agent's joints and
        derive the [qpos_start_idx, qpos_end_idx) range.
        '''
        model_joint_names = self.env.model.joint_names
        self.join_names = list_filter(self.in_scope, model_joint_names)
        self.joint_ids = [model_joint_names.index(joint)
                          for joint in self.join_names]
        self.jnt_qposadr = self.env.model.jnt_qposadr[self.joint_ids]
        self.jnt_type = self.env.model.jnt_type[self.joint_ids]
        self.jnt_nqpos = [self.JNT_NPOS[int(j)] for j in self.jnt_type]
        self.nq = sum(self.jnt_nqpos)
        self.qpos_start_idx = int(self.jnt_qposadr[0])
        self.qpos_end_idx = int(self.jnt_qposadr[-1] + self.jnt_nqpos[-1])
        # self.jnt_dofadr = self.env.model.jnt_dofadr[self.joint_ids]
        # dof = list_filter(lambda x: x >= 0, self.jnt_dofadr)
        # self.qvel_start_idx = int(dof[0])
        # last_dof_body_id = self.body_dofnum.shape[0] - 1
        # while self.body_dofnum[last_dof_body_id] == 0:
        #     last_dof_body_id -= 1
        # self.qvel_end_idx = int(dof[-1] + self.body_dofnum[last_dof_body_id])

    def _set_other_joint(self):
        '''
        For every other agent, cache the (start, end) qpos index range of its
        joints in self._other_qpos_idx, keyed by agent index.
        '''
        model = self.env.model
        model_joint_names = model.joint_names
        self._other_qpos_idx = {}
        for i in range(self.n_agents):
            if i == self.id:
                continue
            other_join_names = list_filter(
                lambda x: self.in_agent_scope(x, i), model_joint_names
            )
            other_joint_ids = [model_joint_names.index(joint)
                               for joint in other_join_names]
            other_jnt_qposadr = model.jnt_qposadr[other_joint_ids]
            jnt_type = model.jnt_type[other_joint_ids]
            jnt_nqpos = [self.JNT_NPOS[int(j)] for j in jnt_type]
            nq = sum(jnt_nqpos)
            qpos_start_idx = int(other_jnt_qposadr[0])
            qpos_end_idx = int(other_jnt_qposadr[-1] + jnt_nqpos[-1])
            # The slice-based accessors are only valid if the other agent's
            # qpos entries are contiguous.
            assert nq == qpos_end_idx - qpos_start_idx, (
                i, nq, qpos_start_idx, qpos_end_idx)
            self._other_qpos_idx[i] = (qpos_start_idx, qpos_end_idx)

    def get_other_agent_qpos(self):
        '''Return a dict mapping each other agent's index to its qpos slice.'''
        other_qpos = {}
        for i in range(self.n_agents):
            if i == self.id:
                continue
            startid, endid = self._other_qpos_idx[i]
            other_qpos[i] = self.env.model.data.qpos[startid: endid]
        return other_qpos

    def before_step(self):
        '''Override this'''
        raise NotImplementedError

    def after_step(self):
        '''Override this'''
        raise NotImplementedError

    def _get_obs(self):
        '''Override this; set_observation_space reads `.size` of the result.'''
        raise NotImplementedError

    def get_body_com(self, body_name):
        '''
        Return the com_subtree entry of the body named
        '<scope>/<body_name>'. Requires set_env to have been called.
        '''
        assert self._env_init, "Env reference is not set"
        full_name = six.b(self.scope + '/' + body_name)
        idx = self.body_ids[self.body_names.index(full_name)]
        return self.env.model.data.com_subtree[idx]

    def get_cfrc_ext(self):
        '''Return cfrc_ext rows for this agent's bodies.'''
        assert self._env_init, "Env reference is not set"
        return self.env.model.data.cfrc_ext[self.body_ids]

    def depricated_get_qpos(self):
        '''
        Gather this agent's qpos joint by joint into an (nq, 1) array.
        Unlike get_qpos, this does not assume the entries are contiguous.
        '''
        qpos = np.zeros((self.nq, 1))
        cnt = 0
        for j, start_idx in enumerate(self.jnt_qposadr):
            jlen = self.jnt_nqpos[j]
            qpos[cnt: cnt + jlen] = \
                self.env.model.data.qpos[start_idx: start_idx + jlen]
            cnt += jlen
        return qpos

    def get_qpos(self):
        '''
        Return this agent's slice of qpos.

        Note: this relies on the qpos for one agent being contiguously
        located; this is generally true, use depricated_get_qpos if not.
        '''
        return self.env.model.data.qpos[
            self.qpos_start_idx: self.qpos_end_idx]

    def get_other_qpos(self):
        '''
        Return qpos with this agent's slice removed.

        Note: this relies on the qpos for one agent being contiguously
        located; this is generally true, use depricated_get_qpos if not.
        '''
        left_part = self.env.model.data.qpos[:self.qpos_start_idx]
        right_part = self.env.model.data.qpos[self.qpos_end_idx:]
        return np.concatenate((left_part, right_part), axis=0)

    def get_qvel(self):
        '''
        Return this agent's slice of qvel.

        Note: this relies on the qvel for one agent being contiguously
        located; this is generally true, follow depricated_get_qpos if not.
        '''
        return self.env.model.data.qvel[
            self.qvel_start_idx: self.qvel_end_idx]

    def get_qfrc_actuator(self):
        '''Return this agent's slice of qfrc_actuator (same range as qvel).'''
        return self.env.model.data.qfrc_actuator[
            self.qvel_start_idx: self.qvel_end_idx]

    def get_cvel(self):
        '''Return cvel rows for this agent's bodies.'''
        return self.env.model.data.cvel[self.body_ids]

    def get_body_mass(self):
        '''Return body_mass rows for this agent's bodies.'''
        return self.env.model.body_mass[self.body_ids]

    def get_xipos(self):
        '''Return xipos rows for this agent's bodies.'''
        return self.env.model.data.xipos[self.body_ids]

    def get_cinert(self):
        '''Return cinert rows for this agent's bodies.'''
        return self.env.model.data.cinert[self.body_ids]

    def get_xmat(self):
        '''Return xmat rows for this agent's bodies.'''
        return self.env.model.data.xmat[self.body_ids]

    def get_torso_xmat(self):
        '''Return the xmat row of the body named '<scope>/torso'.'''
        # Use self.scope so the prefix matches the one used to select
        # self.body_names (and get_body_com).
        torso_name = six.b(self.scope + '/torso')
        torso_id = self.body_ids[self.body_names.index(torso_name)]
        return self.env.model.data.xmat[torso_id]

    # def get_ctrl(self):
    #     return self.env.model.data.ctrl[self.joint_ids]

    def set_xyz(self, xyz):
        '''
        Set (x, y, z) position of the agent.

        Any element can be None, meaning "leave that coordinate unchanged";
        at least one element must be given. The values are written to the
        first three qpos entries of this agent and pushed via env.set_state.
        '''
        # Compare against None rather than truthiness so that a coordinate
        # of 0 is applied instead of being silently skipped.
        assert any(coord is not None for coord in xyz)
        start = self.qpos_start_idx
        qpos = self.env.model.data.qpos.flatten().copy()
        for offset in range(3):
            coord = xyz[offset]
            if coord is not None:
                qpos[start + offset] = coord
        qvel = self.env.model.data.qvel.flatten()
        self.env.set_state(qpos, qvel)

    def set_margin(self, margin):
        '''Set geom_margin to `margin` for every geom in this agent's scope.'''
        agent_geom_ids = [i for i, name in enumerate(self.env.model.geom_names)
                          if self.in_scope(name)]
        m = self.env.model.geom_margin.copy()
        print("Resetting", self.scope, "margins to", margin)
        m[agent_geom_ids] = margin
        # Assign the whole array back to the model attribute.
        self.env.model.geom_margin = m

    def reached_goal(self):
        '''Override this'''
        raise NotImplementedError

    def set_goal(self):
        '''
        Override if needed; this is called when initializing the agent and
        also if the goal needs to be changed on reset.
        '''
        pass

    def reset_agent(self):
        '''Override this'''
        pass