gym_hil/wrappers/hil_wrappers.py (168 lines of code) (raw):

#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gymnasium wrappers used for human-in-the-loop control of gym_hil environments."""

import logging
import sys
import time

import gymnasium as gym
import numpy as np

from gym_hil.mujoco_gym_env import MAX_GRIPPER_COMMAND

# Default per-axis end-effector step size, keyed the way EEActionWrapper expects
# ("x", "y", "z"). Units are presumably meters per unit action — TODO confirm.
DEFAULT_EE_STEP_SIZE = {"x": 0.025, "y": 0.025, "z": 0.025}


class GripperPenaltyWrapper(gym.Wrapper):
    """Report a penalty in ``info["discrete_penalty"]`` for gripper commands that
    push toward a state the gripper is already in.

    After every step, ``info["discrete_penalty"]`` is ``penalty`` when either
    ``action[-1] < -0.5`` while the previous normalized gripper position was
    ``> 0.9``, or ``action[-1] > 0.5`` while it was ``< 0.1``. Otherwise it is
    ``0.0``. The reward itself is not modified.

    NOTE(review): which extreme means "open" vs "closed" depends on the sign
    convention of the action this wrapper receives and of ``get_gripper_pose``;
    confirm against the env before relying on it.
    """

    def __init__(self, env, penalty=-0.05):
        """
        Args:
            env: The environment to wrap. Its unwrapped env must provide
                ``get_gripper_pose()``.
            penalty: Value written to ``info["discrete_penalty"]`` when triggered.
        """
        super().__init__(env)
        self.penalty = penalty
        # Normalized gripper position observed after the previous reset/step;
        # None until the first reset().
        self.last_gripper_pos = None

    def _normalized_gripper_pos(self):
        """Return the current gripper pose divided by MAX_GRIPPER_COMMAND."""
        return self.unwrapped.get_gripper_pose() / MAX_GRIPPER_COMMAND

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.last_gripper_pos = self._normalized_gripper_pos()
        return obs, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)

        info["discrete_penalty"] = 0.0
        last = self.last_gripper_pos
        # Without a baseline (step() before reset()) there is nothing to compare
        # against, so no penalty is reported instead of raising a TypeError.
        if last is not None and (
            (action[-1] < -0.5 and last > 0.9) or (action[-1] > 0.5 and last < 0.1)
        ):
            info["discrete_penalty"] = self.penalty

        self.last_gripper_pos = self._normalized_gripper_pos()
        return observation, reward, terminated, truncated, info


class EEActionWrapper(gym.ActionWrapper):
    """Expose a compact end-effector action space and expand it to the env's 7-D action.

    The exposed action space is ``[dx, dy, dz]`` with each entry in ``[-1, 1]``.
    When ``use_gripper`` is set, a fourth gripper command in ``[0, 2]`` is
    appended. :meth:`action` scales the translation by the per-axis step size and
    returns ``[x, y, z, rx, ry, rz, gripper]`` with zero rotation.
    """

    def __init__(self, env, ee_action_step_size, use_gripper=False):
        """
        Args:
            env: The environment to wrap.
            ee_action_step_size: Mapping with keys ``"x"``, ``"y"``, ``"z"`` giving
                the scale applied to each translation component.
            use_gripper: Whether to expose a gripper command as a fourth action.
        """
        super().__init__(env)
        self.ee_action_step_size = ee_action_step_size
        self.use_gripper = use_gripper
        self._ee_step_size = np.array(
            [
                ee_action_step_size["x"],
                ee_action_step_size["y"],
                ee_action_step_size["z"],
            ]
        )

        num_actions = 3
        # Initialize action space bounds for the non-gripper case
        action_space_bounds_min = -np.ones(num_actions)
        action_space_bounds_max = np.ones(num_actions)

        if self.use_gripper:
            # Gripper command range [0, 2]. InputsControlWrapper below emits
            # 0.0 for "close", 2.0 for "open" and 1.0 otherwise.
            action_space_bounds_min = np.concatenate([action_space_bounds_min, [0.0]])
            action_space_bounds_max = np.concatenate([action_space_bounds_max, [2.0]])
            num_actions += 1

        self.action_space = gym.spaces.Box(
            low=action_space_bounds_min,
            high=action_space_bounds_max,
            shape=(num_actions,),
            dtype=np.float32,
        )

    def action(self, action):
        """
        Mujoco env is expecting a 7D action space
        [x, y, z, rx, ry, rz, gripper_open]
        For the moment we only control the x, y, z, gripper

        Args:
            action: Action from the exposed space (3 or 4 entries, see class docstring).

        Returns:
            A 7-element numpy array ``[x, y, z, 0, 0, 0, gripper]``. The gripper
            entry is 0.0 when ``use_gripper`` is False.
        """
        # action between -1 and 1, scale to step_size
        action_xyz = action[:3] * self._ee_step_size
        # TODO: Extend to enable orientation control
        actions_orn = np.zeros(3)

        gripper_open_command = [0.0]
        if self.use_gripper:
            # NOTE: Normalize gripper action from [0, 2] -> [-1, 1]
            gripper_open_command = [action[-1] - 1.0]

        return np.concatenate([action_xyz, actions_orn, gripper_open_command])


class InputsControlWrapper(gym.Wrapper):
    """
    Wrapper that allows controlling a gym environment with a gamepad.

    This wrapper intercepts the step method and allows human input via gamepad
    to override the agent's actions when desired.
    """

    def __init__(
        self,
        env,
        x_step_size=1.0,
        y_step_size=1.0,
        z_step_size=1.0,
        use_gripper=False,
        auto_reset=False,
        input_threshold=0.001,
        use_gamepad=True,
        controller_config_path=None,
    ):
        """
        Initialize the inputs controller wrapper.

        Args:
            env: The environment to wrap
            x_step_size: Base movement step size for X axis in meters
            y_step_size: Base movement step size for Y axis in meters
            z_step_size: Base movement step size for Z axis in meters
            use_gripper: Whether to use gripper control
            auto_reset: Whether to auto reset the environment when episode ends
            input_threshold: Minimum movement delta to consider as active input
                (stored on the wrapper; not read elsewhere in this class)
            use_gamepad: Whether to use gamepad or keyboard control
            controller_config_path: Path to the controller configuration JSON file
                (only passed to the non-macOS GamepadController)
        """
        super().__init__(env)
        # Imported lazily so the controller backends are only loaded when this
        # wrapper is actually used.
        from gym_hil.wrappers.intervention_utils import (
            GamepadController,
            GamepadControllerHID,
            KeyboardController,
        )

        if use_gamepad:
            # use HidApi for macos
            if sys.platform == "darwin":
                self.controller = GamepadControllerHID(
                    x_step_size=x_step_size,
                    y_step_size=y_step_size,
                    z_step_size=z_step_size,
                )
            else:
                self.controller = GamepadController(
                    x_step_size=x_step_size,
                    y_step_size=y_step_size,
                    z_step_size=z_step_size,
                    config_path=controller_config_path,
                )
        else:
            self.controller = KeyboardController(
                x_step_size=x_step_size,
                y_step_size=y_step_size,
                z_step_size=z_step_size,
            )

        self.auto_reset = auto_reset
        self.use_gripper = use_gripper
        self.input_threshold = input_threshold

        self.controller.start()

    def get_gamepad_action(self):
        """
        Get the current action from the gamepad if any input is active.

        Returns:
            Tuple of (is_active, action, terminate_episode, success, rerecord_episode):
                is_active: Whether the controller requests an intervention.
                action: float array [dx, dy, dz] from the controller, plus a
                    gripper command (2.0 open, 0.0 close, 1.0 otherwise) when
                    ``use_gripper`` is set.
                terminate_episode: True when the controller reports any episode-end status.
                success: True when that status is "success".
                rerecord_episode: True when that status is "rerecord_episode".
        """
        # Update the controller to get fresh inputs
        self.controller.update()

        # Get movement deltas from the controller
        delta_x, delta_y, delta_z = self.controller.get_deltas()
        intervention_is_active = self.controller.should_intervene()

        # Create action from gamepad input
        gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)

        if self.use_gripper:
            gripper_command = self.controller.gripper_command()
            if gripper_command == "open":
                gamepad_action = np.concatenate([gamepad_action, [2.0]])
            elif gripper_command == "close":
                gamepad_action = np.concatenate([gamepad_action, [0.0]])
            else:
                gamepad_action = np.concatenate([gamepad_action, [1.0]])

        # Check episode ending buttons. controller.get_episode_end_status() returns
        # None while the episode continues; "success" and "rerecord_episode" are
        # handled here (presumably also "failure" — confirm in intervention_utils).
        episode_end_status = self.controller.get_episode_end_status()
        terminate_episode = episode_end_status is not None
        success = episode_end_status == "success"
        rerecord_episode = episode_end_status == "rerecord_episode"

        return (
            intervention_is_active,
            gamepad_action,
            terminate_episode,
            success,
            rerecord_episode,
        )

    def step(self, action):
        """
        Step the environment, using gamepad input to override actions when active.

        Args:
            action: Original action from agent

        Returns:
            observation, reward, terminated, truncated, info

            ``info`` gains ``"is_intervention"``, ``"action_intervention"`` (the
            action actually sent to the env) and ``"rerecord_episode"``; when the
            episode ends it also gains ``"next.success"``.
        """
        # Get gamepad state and action
        (
            is_intervention,
            gamepad_action,
            terminate_episode,
            success,
            rerecord_episode,
        ) = self.get_gamepad_action()

        # Update episode ending state if requested
        if terminate_episode:
            logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}")

        if is_intervention:
            action = gamepad_action

        # Step the environment
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Add episode ending if requested via gamepad. NOTE(review): truncation is
        # also folded into `terminated` here; kept as-is since downstream code may
        # depend on it.
        terminated = terminated or truncated or terminate_episode

        if success:
            reward = 1.0
            logging.info("Episode ended successfully with reward 1.0")

        info["is_intervention"] = is_intervention
        info["action_intervention"] = action
        info["rerecord_episode"] = rerecord_episode

        # If episode ended, reset the state
        if terminated or truncated:
            # Add success/failure information to info dict
            info["next.success"] = success

            # Auto reset if configured
            if self.auto_reset:
                obs, reset_info = self.reset()
                info.update(reset_info)

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the controller state, then the wrapped environment."""
        self.controller.reset()
        return self.env.reset(**kwargs)

    def close(self):
        """Clean up resources when environment closes."""
        # Stop the controller (it may be missing if __init__ failed part-way).
        if hasattr(self, "controller"):
            self.controller.stop()

        # Call the parent close method
        return self.env.close()


class ResetDelayWrapper(gym.Wrapper):
    """
    Wrapper that adds a time delay when resetting the environment.

    This can be useful for adding a pause between episodes to allow for human observation.
    """

    def __init__(self, env, delay_seconds=1.0):
        """
        Initialize the time delay reset wrapper.

        Args:
            env: The environment to wrap
            delay_seconds: The number of seconds to delay during reset
        """
        super().__init__(env)
        self.delay_seconds = delay_seconds

    def reset(self, **kwargs):
        """Reset the environment with a time delay."""
        # Add the time delay (blocks the calling thread before the reset)
        logging.info(f"Reset delay of {self.delay_seconds} seconds")
        time.sleep(self.delay_seconds)

        # Call the parent reset method
        return self.env.reset(**kwargs)