gym_xarm/tasks/base.py (238 lines of code) (raw):

import os
from copy import deepcopy

import gymnasium as gym
import mujoco
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import _ALL_RENDERERS, MujocoRenderer
from gymnasium_robotics.utils import mujoco_utils

from gym_xarm.tasks import mocap

# "rgb_array" is always offered; "human" is only offered when the GLFW backend
# is explicitly selected through the MUJOCO_GL environment variable.
RENDER_MODES = ["rgb_array"]
if os.environ.get("MUJOCO_GL") == "glfw":
    RENDER_MODES.append("human")
elif os.environ.get("MUJOCO_GL") not in _ALL_RENDERERS:
    # MUJOCO_GL is unset or is not a key of _ALL_RENDERERS: fall back to EGL.
    os.environ["MUJOCO_GL"] = "egl"


class Base(gym.Env):
    """
    Superclass for all gym-xarm environments.

    Subclasses are expected to provide:
        - ``metadata["action_space"]``: a sized value (its ``len`` defines the
          action dimension; ``"w" in ...`` is tested to decide the padding),
        - ``_get_obs()``, ``get_reward()``, ``is_success()`` and ``_sample_goal()``.

    Args:
        task (str): name of the xml environment file (without the ``.xml``
            extension), looked up in the ``assets`` directory next to this module
        obs_type (str): one of ``"state"``, ``"pixels"`` or ``"pixels_agent_pos"``
        render_mode (str): render mode forwarded to the MuJoCo renderer
        gripper_rotation (list): initial rotation of the gripper (given as a
            quaternion). Defaults to ``[0, 1, 0, 0]``
        observation_width (int): width of the offscreen buffer used for observations
        observation_height (int): height of the offscreen buffer used for observations
        visualization_width (int): width of the offscreen buffer used by :meth:`render`
        visualization_height (int): height of the offscreen buffer used by :meth:`render`

    Raises:
        OSError: if the xml file for ``task`` does not exist.
    """

    metadata = {
        "render_modes": RENDER_MODES,
        "render_fps": 25,
    }
    n_substeps = 20
    # Mapping joint name -> qpos applied by `_env_setup`. Note that
    # `_initialize_simulation` later shadows this on the instance with a copy of
    # the full `data.qpos` array.
    initial_qpos = {}
    _mujoco = mujoco
    _utils = mujoco_utils

    def __init__(
        self,
        task,
        obs_type="state",
        render_mode="rgb_array",
        gripper_rotation=None,
        observation_width=84,
        observation_height=84,
        visualization_width=680,
        visualization_height=680,
    ):
        # Coordinates
        if gripper_rotation is None:
            gripper_rotation = [0, 1, 0, 0]
        self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
        self.center_of_table = np.array([1.655, 0.3, 0.63625])
        # Vertical bounds enforced on the gripper by `_limit_gripper`.
        self.max_z = 1.2
        self.min_z = 0.2

        # Observations
        self.obs_type = obs_type

        # Rendering
        self.render_mode = render_mode
        self.observation_width = observation_width
        self.observation_height = observation_height
        self.visualization_width = visualization_width
        self.visualization_height = visualization_height

        # Assets
        self.xml_path = os.path.join(os.path.dirname(__file__), "assets", f"{task}.xml")
        if not os.path.exists(self.xml_path):
            raise OSError(f"File {self.xml_path} does not exist")

        # Initialize sim, spaces & renderers
        self._initialize_simulation()
        self.observation_renderer = self._initialize_renderer(renderer_type="observation")
        self.visualization_renderer = self._initialize_renderer(renderer_type="visualization")
        self.observation_space = self._initialize_observation_space()
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(self.metadata["action_space"]),))
        # Zero padding sized so that action + padding has 4 entries.
        self.action_padding = np.zeros(4 - len(self.metadata["action_space"]), dtype=np.float32)

        # The declared render fps must match the simulated time per step.
        assert (
            int(np.round(1.0 / self.dt)) == self.metadata["render_fps"]
        ), f'Expected value: {int(np.round(1.0 / self.dt))}, Actual value: {self.metadata["render_fps"]}'

        if "w" not in self.metadata["action_space"]:
            self.action_padding[-1] = 1.0

    def _initialize_simulation(self):
        """Initialize MuJoCo simulation data structures mjModel and mjData.

        Also sets the offscreen buffer size to the observation size, runs
        `_env_setup`, and snapshots time, qpos and qvel so that `_reset_sim`
        can restore them.
        """
        self.model = self._mujoco.MjModel.from_xml_path(self.xml_path)
        self.data = self._mujoco.MjData(self.model)
        self._model_names = self._utils.MujocoModelNames(self.model)

        self.model.vis.global_.offwidth = self.observation_width
        self.model.vis.global_.offheight = self.observation_height

        self._env_setup(initial_qpos=self.initial_qpos)
        self.initial_time = self.data.time
        self.initial_qpos = np.copy(self.data.qpos)
        self.initial_qvel = np.copy(self.data.qvel)

    def _env_setup(self, initial_qpos):
        """Initial configuration of the environment.

        Can be used to configure initial state and extract information from the
        simulation.

        Args:
            initial_qpos (dict): mapping joint name -> qpos value to apply before
                the mocap reset and goal sampling.
        """
        for name, value in initial_qpos.items():
            # Use the same helper as `_set_gripper`; the previous
            # `self.data.set_joint_qpos(...)` call was inconsistent with it.
            self._utils.set_joint_qpos(self.model, self.data, name, value)
        mocap.reset(self.model, self.data)
        mujoco.mj_forward(self.model, self.data)
        self._sample_goal()
        mujoco.mj_forward(self.model, self.data)

    def _initialize_observation_space(self):
        """Build the observation space matching `self.obs_type`.

        Shapes for the state-based entries are taken from one call to `get_obs`.

        Raises:
            ValueError: if `self.obs_type` is not a supported value.
        """
        image_shape = (self.observation_height, self.observation_width, 3)
        obs = self.get_obs()
        if self.obs_type == "state":
            observation_space = gym.spaces.Box(-1000.0, 1000.0, shape=obs.shape, dtype=np.float64)
        elif self.obs_type == "pixels":
            observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
        elif self.obs_type == "pixels_agent_pos":
            observation_space = gym.spaces.Dict(
                {
                    "pixels": gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
                    "agent_pos": gym.spaces.Box(
                        low=-1000.0, high=1000.0, shape=obs["agent_pos"].shape, dtype=np.float64
                    ),
                }
            )
        else:
            raise ValueError(
                f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, pixels_agent_pos]"
            )
        return observation_space

    def _initialize_renderer(self, renderer_type: str):
        """Create a `MujocoRenderer` for observations or for visualization.

        Args:
            renderer_type: ``"observation"`` or ``"visualization"``.

        Raises:
            ValueError: if `renderer_type` is neither of the two supported values.
        """
        if renderer_type == "observation":
            model = self.model
        elif renderer_type == "visualization":
            # HACK: gymnasium doesn't allow for custom size rendering on-the-fly, so we
            # initialize another renderer with appropriate size for visualization purposes
            # see https://gymnasium.farama.org/content/migration-guide/#environment-render
            model = deepcopy(self.model)
            model.vis.global_.offwidth = self.visualization_width
            model.vis.global_.offheight = self.visualization_height
        else:
            raise ValueError(
                f"Unknown render type {renderer_type}. Must be one of [observation, visualization]"
            )
        return MujocoRenderer(model, self.data)

    @property
    def dt(self):
        """Return the timestep of each Gymnasium step."""
        return self.n_substeps * self.model.opt.timestep

    @property
    def eef(self):
        """Position of the "grasp" site relative to `center_of_table`."""
        return self._utils.get_site_xpos(self.model, self.data, "grasp") - self.center_of_table

    @property
    def eef_velp(self):
        """Positional velocity of the "grasp" site, scaled by `dt`."""
        return self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt

    @property
    def gripper_angle(self):
        """qpos of the "right_outer_knuckle_joint" joint."""
        return self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")

    @property
    def robot_state(self):
        """Concatenation of the end-effector position term and `gripper_angle`."""
        # NOTE(review): `eef` is already relative to `center_of_table`, so the
        # offset is subtracted twice here. Kept as-is because changing it would
        # shift observation values for existing consumers -- confirm intent.
        return np.concatenate([self.eef - self.center_of_table, self.gripper_angle])

    @property
    def obj(self):
        """Position of the "object_site" site relative to `center_of_table`."""
        return self._utils.get_site_xpos(self.model, self.data, "object_site") - self.center_of_table

    @property
    def obj_rot(self):
        """Last four qpos entries of the "object_joint0" joint."""
        return self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:]

    @property
    def obj_velp(self):
        """Positional velocity of the "object_site" site, scaled by `dt`."""
        return self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt

    @property
    def obj_velr(self):
        """Rotational velocity of the "object_site" site, scaled by `dt`."""
        return self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt

    def is_success(self):
        """Indicates whether or not the achieved goal successfully achieved the desired goal.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # Raise (not return) so a missing override fails loudly instead of
        # handing a truthy exception instance back to `step` as `terminated`.
        raise NotImplementedError()

    def get_reward(self):
        """Return the reward for the current state. Must be implemented by subclasses."""
        raise NotImplementedError()

    def _get_obs(self):
        """Return the state observation used when ``obs_type == "state"``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def _sample_goal(self):
        """Samples a new goal and returns it."""
        raise NotImplementedError()

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ):
        """Reset MuJoCo simulation to initial state.

        Note: Attempt to reset the simulator. Since we randomize initial conditions, it
        is possible to get into a state with numerical issues (e.g. due to penetration or
        Gimbal lock) or we may not achieve an initial condition (e.g. an object is within the hand).
        In this case, we just keep randomizing until we eventually achieve a valid initial
        configuration.

        Args:
            seed (optional integer): The seed that is used to initialize the environment's PRNG (`np_random`).
                Defaults to None.
            options (optional dictionary): Can be used when `reset` is overridden, for additional information
                to specify how the environment is reset. Not used by this base implementation.

        Returns:
            observation (dictionary) : Observation of the initial state.
            info (dictionary): This dictionary contains auxiliary information complementing ``observation``.
                It should be analogous to the ``info`` returned by :meth:`step`.
        """
        super().reset(seed=seed)
        did_reset_sim = False
        while not did_reset_sim:
            did_reset_sim = self._reset_sim()
        observation = self.get_obs()
        info = {}
        return observation, info

    def _reset_sim(self):
        """Resets a simulation and indicates whether or not it was successful.

        If a reset was unsuccessful (e.g. if a randomized state caused an error in the
        simulation), this method should indicate such a failure by returning False.
        In such a case, this method will be called again to attempt the reset again.
        """
        self.data.time = self.initial_time
        self.data.qpos[:] = np.copy(self.initial_qpos)
        self.data.qvel[:] = np.copy(self.initial_qvel)
        self._sample_goal()
        self._mujoco.mj_step(self.model, self.data, nstep=10)
        return True

    def get_obs(self):
        """Return the observation corresponding to `self.obs_type`.

        Returns:
            The result of `_get_obs()` for ``"state"``, the rendered image for
            ``"pixels"``, or a dict with ``"pixels"`` and ``"agent_pos"`` for
            ``"pixels_agent_pos"``.

        Raises:
            ValueError: if `self.obs_type` is not a supported value.
        """
        if self.obs_type == "state":
            return self._get_obs()
        pixels = self._render(renderer=self.observation_renderer)
        if self.obs_type == "pixels":
            return pixels
        elif self.obs_type == "pixels_agent_pos":
            return {
                "pixels": pixels,
                "agent_pos": self.robot_state,
            }
        else:
            raise ValueError(
                f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, pixels_agent_pos]"
            )

    def step(self, action):
        """Apply `action`, advance the simulation and return the transition.

        Args:
            action: array of shape ``(4,)`` contained in `self.action_space`.

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)`` where
            ``terminated`` is the value of `is_success()`, ``truncated`` is
            always False and ``info`` holds ``"is_success"``.
        """
        assert action.shape == (4,)
        assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
        self._apply_action(action)
        self._mujoco.mj_step(self.model, self.data, nstep=2)
        self._step_callback()
        observation = self.get_obs()
        reward = self.get_reward()
        terminated = is_success = self.is_success()
        truncated = False
        info = {"is_success": is_success}
        return observation, reward, terminated, truncated, info

    def _step_callback(self):
        """Hook run after each simulation step; recomputes derived quantities."""
        self._mujoco.mj_forward(self.model, self.data)

    def _limit_gripper(self, gripper_pos, pos_ctrl):
        """Clamp `pos_ctrl` so the gripper is not pushed further out of bounds.

        For each axis, when `gripper_pos` is beyond the upper bound the control
        is capped at 0 from above; when it is beyond the lower bound the control
        is floored at 0. `pos_ctrl` is modified in place and also returned.
        """
        # x bounds are asymmetric around an anchor offset from the table center.
        if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
            pos_ctrl[0] = min(pos_ctrl[0], 0)
        if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3:
            pos_ctrl[0] = max(pos_ctrl[0], 0)
        if gripper_pos[1] > self.center_of_table[1] + 0.3:
            pos_ctrl[1] = min(pos_ctrl[1], 0)
        if gripper_pos[1] < self.center_of_table[1] - 0.3:
            pos_ctrl[1] = max(pos_ctrl[1], 0)
        if gripper_pos[2] > self.max_z:
            pos_ctrl[2] = min(pos_ctrl[2], 0)
        if gripper_pos[2] < self.min_z:
            pos_ctrl[2] = max(pos_ctrl[2], 0)
        return pos_ctrl

    def _apply_action(self, action):
        """Convert a 4-d action into a mocap command and apply it.

        The first three entries are the position control (limited by
        `_limit_gripper`, then scaled by ``1 / n_substeps``); the fourth is the
        gripper control, duplicated into two entries. The caller's array is not
        modified.
        """
        assert action.shape == (4,)
        action = action.copy()  # `_limit_gripper` mutates its input in place
        pos_ctrl, gripper_ctrl = action[:3], action[3]
        pos_ctrl = self._limit_gripper(
            self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
        ) * (1 / self.n_substeps)
        gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
        mocap.apply_action(
            self.model,
            self._model_names,
            self.data,
            np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
        )

    def _set_gripper(self, gripper_pos, gripper_rotation):
        """Place the "robot0:mocap2" mocap body and zero the gripper joints."""
        self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap2", gripper_pos)
        self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap2", gripper_rotation)
        self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0)
        # NOTE(review): presumably further gripper joints in the model's qpos
        # layout -- confirm against the xml before relying on these indices.
        self.data.qpos[10] = 0.0
        self.data.qpos[12] = 0.0

    def render(self):
        """Render a frame with the visualization-sized renderer."""
        return self._render(renderer=self.visualization_renderer)

    def _render(self, renderer: MujocoRenderer):
        """Render from "camera0" with `renderer`; returns a copy of the frame or None."""
        self._render_callback()
        render = renderer.render(self.render_mode, camera_name="camera0")
        return render.copy() if render is not None else None

    def _render_callback(self):
        """Hook run before rendering; recomputes derived quantities."""
        self._mujoco.mj_forward(self.model, self.data)

    def close(self):
        """Close contains the code necessary to "clean up" the environment.

        Terminates any existing WindowViewer instances in the Gymnasium MujocoRenderer.
        """
        if self.observation_renderer is not None:
            self.observation_renderer.close()
        if self.visualization_renderer is not None:
            self.visualization_renderer.close()