gym_hil/mujoco_gym_env.py

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional

import gymnasium as gym
import mujoco
import numpy as np
from gymnasium import spaces

from gym_hil.controllers import opspace

MAX_GRIPPER_COMMAND = 255


@dataclass(frozen=True)
class GymRenderingSpec:
    height: int = 128
    width: int = 128
    camera_id: str | int = -1
    mode: Literal["rgb_array", "human"] = "rgb_array"


class MujocoGymEnv(gym.Env):
    """MujocoEnv with gym interface."""

    def __init__(
        self,
        xml_path: Path,
        seed: int = 0,
        control_dt: float = 0.02,
        physics_dt: float = 0.002,
        render_spec: GymRenderingSpec = GymRenderingSpec(),  # noqa: B008
    ):
        self._model = mujoco.MjModel.from_xml_path(xml_path.as_posix())
        self._model.vis.global_.offwidth = render_spec.width
        self._model.vis.global_.offheight = render_spec.height
        self._data = mujoco.MjData(self._model)
        self._model.opt.timestep = physics_dt
        self._control_dt = control_dt
        self._n_substeps = int(control_dt // physics_dt)
        self._random = np.random.RandomState(seed)
        self._viewer: Optional[mujoco.Renderer] = None
        self._render_specs = render_spec

    def render(self):
        if self._viewer is None:
            self._viewer = mujoco.Renderer(
                model=self._model,
                height=self._render_specs.height,
                width=self._render_specs.width,
            )
        self._viewer.update_scene(self._data, camera=self._render_specs.camera_id)
        return self._viewer.render()

    def close(self) -> None:
        """Release graphics resources if they exist.

        In MuJoCo < 2.3.0 `mujoco.Renderer` had no `close()` member. Calling it
        unconditionally therefore raises `AttributeError`. We check for the
        attribute first and fall back to a no-op, keeping compatibility across
        MuJoCo versions.
        """
        viewer = self._viewer
        if viewer is None:
            return
        if hasattr(viewer, "close") and callable(viewer.close):
            try:  # noqa: SIM105
                viewer.close()
            except Exception:
                # Ignore errors coming from already freed OpenGL contexts or
                # older MuJoCo builds.
                pass
        self._viewer = None

    # Accessors.

    @property
    def model(self) -> mujoco.MjModel:
        return self._model

    @property
    def data(self) -> mujoco.MjData:
        return self._data

    @property
    def control_dt(self) -> float:
        return self._control_dt

    @property
    def physics_dt(self) -> float:
        return self._model.opt.timestep

    @property
    def random_state(self) -> np.random.RandomState:
        return self._random


class FrankaGymEnv(MujocoGymEnv):
    """Base class for Franka Panda robot environments."""

    def __init__(
        self,
        xml_path: Path | None = None,
        seed: int = 0,
        control_dt: float = 0.02,
        physics_dt: float = 0.002,
        render_spec: GymRenderingSpec = GymRenderingSpec(),  # noqa: B008
        render_mode: Literal["rgb_array", "human"] = "rgb_array",
        image_obs: bool = False,
        home_position: np.ndarray = np.asarray((0, -0.785, 0, -2.35, 0, 1.57, np.pi / 4)),  # noqa: B008
        cartesian_bounds: np.ndarray = np.asarray([[0.2, -0.3, 0], [0.6, 0.3, 0.5]]),  # noqa: B008
    ):
        if xml_path is None:
            xml_path = Path(__file__).parent.parent / "gym_hil" / "assets" / "scene.xml"

        super().__init__(
            xml_path=xml_path,
            seed=seed,
            control_dt=control_dt,
            physics_dt=physics_dt,
            render_spec=render_spec,
        )

        self._home_position = home_position
        self._cartesian_bounds = cartesian_bounds

        self.metadata = {
            "render_modes": ["human", "rgb_array"],
            "render_fps": int(np.round(1.0 / self.control_dt)),
        }

        self.render_mode = render_mode
        self.image_obs = image_obs

        # Setup cameras
        camera_name_1 = "front"
        camera_name_2 = "handcam_rgb"
        camera_id_1 = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_CAMERA, camera_name_1)
        camera_id_2 = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_CAMERA, camera_name_2)
        self.camera_id = (camera_id_1, camera_id_2)

        # Cache robot IDs
        self._panda_dof_ids = np.asarray([self._model.joint(f"joint{i}").id for i in range(1, 8)])
        self._panda_ctrl_ids = np.asarray([self._model.actuator(f"actuator{i}").id for i in range(1, 8)])
        self._gripper_ctrl_id = self._model.actuator("fingers_actuator").id
        self._pinch_site_id = self._model.site("pinch").id

        # Setup observation and action spaces
        self._setup_observation_space()
        self._setup_action_space()

        # Initialize renderer
        self._viewer = mujoco.Renderer(self.model, height=render_spec.height, width=render_spec.width)
        self._viewer.render()

    def _setup_observation_space(self):
        """Set up the observation space for the Franka environment."""
        base_obs_space = {
            "agent_pos": spaces.Dict(
                {
                    "tcp_pose": spaces.Box(-np.inf, np.inf, shape=(7,), dtype=np.float32),
                    "tcp_vel": spaces.Box(-np.inf, np.inf, shape=(6,), dtype=np.float32),
                    "gripper_pose": spaces.Box(-1, 1, shape=(1,), dtype=np.float32),
                }
            )
        }

        self.observation_space = spaces.Dict(base_obs_space)

        if self.image_obs:
            self.observation_space = spaces.Dict(
                {
                    **base_obs_space,
                    "pixels": spaces.Dict(
                        {
                            "front": spaces.Box(
                                low=0,
                                high=255,
                                shape=(self._render_specs.height, self._render_specs.width, 3),
                                dtype=np.uint8,
                            ),
                            "wrist": spaces.Box(
                                low=0,
                                high=255,
                                shape=(self._render_specs.height, self._render_specs.width, 3),
                                dtype=np.uint8,
                            ),
                        }
                    ),
                }
            )

    def _setup_action_space(self):
        """Set up the action space for the Franka environment."""
        self.action_space = spaces.Box(
            low=np.asarray([-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], dtype=np.float32),
            high=np.asarray([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32,
        )

    def reset_robot(self):
        """Reset the robot to home position."""
        self._data.qpos[self._panda_dof_ids] = self._home_position
        self._data.ctrl[self._panda_ctrl_ids] = 0.0
        mujoco.mj_forward(self._model, self._data)

        # Reset mocap body to home position
        tcp_pos = self._data.sensor("2f85/pinch_pos").data
        self._data.mocap_pos[0] = tcp_pos

    def apply_action(self, action):
        """Apply the action to the robot."""
        x, y, z, rx, ry, rz, grasp_command = action

        # Set the mocap position
        pos = self._data.mocap_pos[0].copy()
        dpos = np.asarray([x, y, z])
        npos = np.clip(pos + dpos, *self._cartesian_bounds)
        self._data.mocap_pos[0] = npos

        # Set gripper grasp
        g = self._data.ctrl[self._gripper_ctrl_id] / MAX_GRIPPER_COMMAND
        ng = np.clip(g + grasp_command, 0.0, 1.0)
        self._data.ctrl[self._gripper_ctrl_id] = ng * MAX_GRIPPER_COMMAND

        # Apply operational space control
        for _ in range(self._n_substeps):
            tau = opspace(
                model=self._model,
                data=self._data,
                site_id=self._pinch_site_id,
                dof_ids=self._panda_dof_ids,
                pos=self._data.mocap_pos[0],
                ori=self._data.mocap_quat[0],
                joint=self._home_position,
                gravity_comp=True,
            )
            self._data.ctrl[self._panda_ctrl_ids] = tau
            mujoco.mj_step(self._model, self._data)

    def get_robot_state(self):
        """Get the current state of the robot."""
        tcp_pos = self._data.sensor("2f85/pinch_pos").data
        # tcp_quat = self._data.sensor("2f85/pinch_quat").data
        # tcp_vel = self._data.sensor("2f85/pinch_vel").data
        # tcp_angvel = self._data.sensor("2f85/pinch_angvel").data

        qpos = self.data.qpos[self._panda_dof_ids].astype(np.float32)
        qvel = self.data.qvel[self._panda_dof_ids].astype(np.float32)
        gripper_pose = self.get_gripper_pose()

        return np.concatenate([qpos, qvel, gripper_pose, tcp_pos])

    def render(self):
        """Render the environment and return frames from multiple cameras."""
        rendered_frames = []
        for cam_id in self.camera_id:
            self._viewer.update_scene(self.data, camera=cam_id)
            rendered_frames.append(self._viewer.render())
        return rendered_frames

    def get_gripper_pose(self):
        """Get the current pose of the gripper."""
        return np.array([self._data.ctrl[self._gripper_ctrl_id]], dtype=np.float32)
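# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). It only exercises the
# helpers defined in this module and assumes the packaged `assets/scene.xml`,
# its `2f85/pinch_pos` sensor, and the `opspace` controller resolve correctly
# in your install. Gymnasium `reset()`/`step()` are expected to be provided by
# subclasses, so they are not called here.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    env = FrankaGymEnv(image_obs=True)

    # Move the arm to its home configuration and read back the proprioceptive state:
    # [qpos(7), qvel(7), gripper(1), tcp_pos(3)].
    env.reset_robot()
    state = env.get_robot_state()
    print("robot state shape:", state.shape)

    # Apply a small Cartesian delta with the gripper command unchanged, then
    # render both cameras (front and wrist).
    delta = np.array([0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    env.apply_action(delta)
    front, wrist = env.render()
    print("front:", front.shape, "wrist:", wrist.shape)

    env.close()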