gym_hil/wrappers/intervention_utils.py (349 lines of code) (raw):

#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json from pathlib import Path def load_controller_config(controller_name: str, config_path: str | None = None) -> dict: """ Load controller configuration from a JSON file. Args: controller_name: Name of the controller to load. config_path: Path to the config file. If None, uses the package's default config. Returns: Dictionary containing the selected controller's configuration. """ if config_path is None: config_path = Path(__file__).parent.parent / "controller_config.json" with open(config_path) as f: config = json.load(f) controller_config = config[controller_name] if controller_name in config else config["default"] if controller_name not in config: print(f"Controller {controller_name} not found in config. Using default configuration.") return controller_config class InputController: """Base class for input controllers that generate motion deltas.""" def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01): """ Initialize the controller. Args: x_step_size: Base movement step size in meters y_step_size: Base movement step size in meters z_step_size: Base movement step size in meters """ self.x_step_size = x_step_size self.y_step_size = y_step_size self.z_step_size = z_step_size self.running = True self.episode_end_status = None # None, "success", or "failure" self.intervention_flag = False self.open_gripper_command = False self.close_gripper_command = False def start(self): """Start the controller and initialize resources.""" pass def stop(self): """Stop the controller and release resources.""" pass def reset(self): """Reset the controller.""" pass def get_deltas(self): """Get the current movement deltas (dx, dy, dz) in meters.""" return 0.0, 0.0, 0.0 def update(self): """Update controller state - call this once per frame.""" pass def __enter__(self): """Support for use in 'with' statements.""" self.start() return self def __exit__(self, exc_type, exc_val, exc_tb): """Ensure resources are released when exiting 'with' block.""" self.stop() def get_episode_end_status(self): """ Get the current episode end status. Returns: None if episode should continue, "success" or "failure" otherwise """ status = self.episode_end_status self.episode_end_status = None # Reset after reading return status def should_intervene(self): """Return True if intervention flag was set.""" return self.intervention_flag def gripper_command(self): """Return the current gripper command.""" if self.open_gripper_command == self.close_gripper_command: return "no-op" elif self.open_gripper_command: return "open" elif self.close_gripper_command: return "close" class KeyboardController(InputController): """Generate motion deltas from keyboard input.""" def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01): super().__init__(x_step_size, y_step_size, z_step_size) self.key_states = { "forward_x": False, "backward_x": False, "forward_y": False, "backward_y": False, "forward_z": False, "backward_z": False, "success": False, "failure": False, "intervention": False, "rerecord": False, } self.listener = None def start(self): """Start the keyboard listener.""" from pynput import keyboard def on_press(key): try: if key == keyboard.Key.up: self.key_states["forward_x"] = True elif key == keyboard.Key.down: self.key_states["backward_x"] = True elif key == keyboard.Key.left: self.key_states["forward_y"] = True elif key == keyboard.Key.right: self.key_states["backward_y"] = True elif key == keyboard.Key.shift: self.key_states["backward_z"] = True elif key == keyboard.Key.shift_r: self.key_states["forward_z"] = True elif key == keyboard.Key.ctrl_r: self.open_gripper_command = True elif key == keyboard.Key.ctrl_l: self.close_gripper_command = True elif key == keyboard.Key.enter: self.key_states["success"] = True self.episode_end_status = "success" elif key == keyboard.Key.esc: self.key_states["failure"] = True self.episode_end_status = "failure" elif key == keyboard.Key.space: self.key_states["intervention"] = not self.key_states["intervention"] elif key == keyboard.Key.r: self.key_states["rerecord"] = True except AttributeError: pass def on_release(key): try: if key == keyboard.Key.up: self.key_states["forward_x"] = False elif key == keyboard.Key.down: self.key_states["backward_x"] = False elif key == keyboard.Key.left: self.key_states["forward_y"] = False elif key == keyboard.Key.right: self.key_states["backward_y"] = False elif key == keyboard.Key.shift: self.key_states["backward_z"] = False elif key == keyboard.Key.shift_r: self.key_states["forward_z"] = False elif key == keyboard.Key.ctrl_r: self.open_gripper_command = False elif key == keyboard.Key.ctrl_l: self.close_gripper_command = False except AttributeError: pass self.listener = keyboard.Listener(on_press=on_press, on_release=on_release) self.listener.start() print("Keyboard controls:") print(" Arrow keys: Move in X-Y plane") print(" Shift and Shift_R: Move in Z axis") print(" Right Ctrl and Left Ctrl: Open and close gripper") print(" Enter: End episode with SUCCESS") print(" Backspace: End episode with FAILURE") print(" Space: Start/Stop Intervention") print(" ESC: Exit") def stop(self): """Stop the keyboard listener.""" if self.listener and self.listener.is_alive(): self.listener.stop() def get_deltas(self): """Get the current movement deltas from keyboard state.""" delta_x = delta_y = delta_z = 0.0 if self.key_states["forward_x"]: delta_x += self.x_step_size if self.key_states["backward_x"]: delta_x -= self.x_step_size if self.key_states["forward_y"]: delta_y += self.y_step_size if self.key_states["backward_y"]: delta_y -= self.y_step_size if self.key_states["forward_z"]: delta_z += self.z_step_size if self.key_states["backward_z"]: delta_z -= self.z_step_size return delta_x, delta_y, delta_z def should_save(self): """Return True if Enter was pressed (save episode).""" return self.key_states["success"] or self.key_states["failure"] def should_intervene(self): """Return True if intervention flag was set.""" return self.key_states["intervention"] def reset(self): """Reset the controller.""" for key in self.key_states: self.key_states[key] = False class GamepadController(InputController): """Generate motion deltas from gamepad input.""" def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1, config_path=None): super().__init__(x_step_size, y_step_size, z_step_size) self.deadzone = deadzone self.joystick = None self.intervention_flag = False self.config_path = config_path self.controller_config = None def start(self): """Initialize pygame and the gamepad.""" import pygame pygame.init() pygame.joystick.init() if pygame.joystick.get_count() == 0: print("No gamepad detected. Please connect a gamepad and try again.") self.running = False return self.joystick = pygame.joystick.Joystick(0) self.joystick.init() joystick_name = self.joystick.get_name() print(f"Initialized gamepad: {joystick_name}") # Load controller configuration based on joystick name self.controller_config = load_controller_config(joystick_name, self.config_path) # Get button mappings from config buttons = self.controller_config.get("buttons", {}) print("Gamepad controls:") print(f" {buttons.get('rb', 'RB')} button: Intervention") print(" Left analog stick: Move in X-Y plane") print(" Right analog stick (vertical): Move in Z axis") print(f" {buttons.get('lt', 'LT')} button: Close gripper") print(f" {buttons.get('rt', 'RT')} button: Open gripper") print(f" {buttons.get('b', 'B')}/Circle button: Exit") print(f" {buttons.get('y', 'Y')}/Triangle button: End episode with SUCCESS") print(f" {buttons.get('a', 'A')}/Cross button: End episode with FAILURE") print(f" {buttons.get('x', 'X')}/Square button: Rerecord episode") def stop(self): """Clean up pygame resources.""" import pygame if pygame.joystick.get_init(): if self.joystick: self.joystick.quit() pygame.joystick.quit() pygame.quit() def update(self): """Process pygame events to get fresh gamepad readings.""" import pygame # Get button mappings from config buttons = self.controller_config.get("buttons", {}) y_button = buttons.get("y", 3) # Default to 3 if not found a_button = buttons.get("a", 0) # Default to 0 if not found (Logitech F310) x_button = buttons.get("x", 2) # Default to 2 if not found (Logitech F310) lt_button = buttons.get("lt", 6) # Default to 6 if not found rt_button = buttons.get("rt", 7) # Default to 7 if not found rb_button = buttons.get("rb", 5) # Default to 5 if not found for event in pygame.event.get(): if event.type == pygame.JOYBUTTONDOWN: if event.button == y_button: self.episode_end_status = "success" elif event.button == a_button: self.episode_end_status = "failure" elif event.button == x_button: self.episode_end_status = "rerecord_episode" elif event.button == lt_button: self.close_gripper_command = True elif event.button == rt_button: self.open_gripper_command = True # Reset episode status on button release elif event.type == pygame.JOYBUTTONUP: if event.button in [x_button, a_button, y_button]: self.episode_end_status = None elif event.button == lt_button: self.close_gripper_command = False elif event.button == rt_button: self.open_gripper_command = False # Check for RB button for intervention flag if self.joystick.get_button(rb_button): self.intervention_flag = True else: self.intervention_flag = False def get_deltas(self): """Get the current movement deltas from gamepad state.""" import pygame try: # Get axis mappings from config axes = self.controller_config.get("axes", {}) axis_inversion = self.controller_config.get("axis_inversion", {}) # Get axis indices from config (with defaults if not found) left_x_axis = axes.get("left_x", 0) left_y_axis = axes.get("left_y", 1) right_y_axis = axes.get("right_y", 3) # Get axis inversion settings (with defaults if not found) invert_left_x = axis_inversion.get("left_x", False) invert_left_y = axis_inversion.get("left_y", True) invert_right_y = axis_inversion.get("right_y", True) # Read joystick axes x_input = self.joystick.get_axis(left_x_axis) # Left/Right y_input = self.joystick.get_axis(left_y_axis) # Up/Down z_input = self.joystick.get_axis(right_y_axis) # Up/Down for Z # Apply deadzone to avoid drift x_input = 0 if abs(x_input) < self.deadzone else x_input y_input = 0 if abs(y_input) < self.deadzone else y_input z_input = 0 if abs(z_input) < self.deadzone else z_input # Apply inversion if configured if invert_left_x: x_input = -x_input if invert_left_y: y_input = -y_input if invert_right_y: z_input = -z_input # Calculate deltas delta_x = y_input * self.y_step_size # Forward/backward delta_y = x_input * self.x_step_size # Left/right delta_z = z_input * self.z_step_size # Up/down return delta_x, delta_y, delta_z except pygame.error: print("Error reading gamepad. Is it still connected?") return 0.0, 0.0, 0.0 class GamepadControllerHID(InputController): """Generate motion deltas from gamepad input using HIDAPI.""" def __init__( self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1, ): """ Initialize the HID gamepad controller. Args: step_size: Base movement step size in meters z_scale: Scaling factor for Z-axis movement deadzone: Joystick deadzone to prevent drift """ super().__init__(x_step_size, y_step_size, z_step_size) self.deadzone = deadzone self.device = None self.device_info = None # Movement values (normalized from -1.0 to 1.0) self.left_x = 0.0 self.left_y = 0.0 self.right_x = 0.0 self.right_y = 0.0 # Button states self.buttons = {} self.quit_requested = False self.save_requested = False def find_device(self): """Look for the gamepad device by vendor and product ID.""" import hid devices = hid.enumerate() for device in devices: device_name = device["product_string"] if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]): return device print("No gamepad found, check the connection and the product string in HID to add your gamepad") return None def start(self): """Connect to the gamepad using HIDAPI.""" import hid self.device_info = self.find_device() if not self.device_info: self.running = False return try: print(f"Connecting to gamepad at path: {self.device_info['path']}") self.device = hid.device() self.device.open_path(self.device_info["path"]) self.device.set_nonblocking(1) manufacturer = self.device.get_manufacturer_string() product = self.device.get_product_string() print(f"Connected to {manufacturer} {product}") print("Gamepad controls (HID mode):") print(" Left analog stick: Move in X-Y plane") print(" Right analog stick: Move in Z axis (vertical)") print(" Button 1/B/Circle: Exit") print(" Button 2/A/Cross: End episode with SUCCESS") print(" Button 3/X/Square: End episode with FAILURE") except OSError as e: print(f"Error opening gamepad: {e}") print("You might need to run this with sudo/admin privileges on some systems") self.running = False def stop(self): """Close the HID device connection.""" if self.device: self.device.close() self.device = None def update(self): """ Read and process the latest gamepad data. Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading """ for _ in range(10): self._update() def _update(self): """Read and process the latest gamepad data.""" if not self.device or not self.running: return try: # Read data from the gamepad data = self.device.read(64) # Interpret gamepad data - this will vary by controller model # These offsets are for the Logitech RumblePad 2 if data and len(data) >= 8: # Normalize joystick values from 0-255 to -1.0-1.0 self.left_x = (data[1] - 128) / 128.0 self.left_y = (data[2] - 128) / 128.0 self.right_x = (data[3] - 128) / 128.0 self.right_y = (data[4] - 128) / 128.0 # Apply deadzone self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y # Parse button states (byte 5 in the Logitech RumblePad 2) buttons = data[5] # Check if RB is pressed then the intervention flag should be set self.intervention_flag = data[6] in [2, 6, 10, 14] # Check if RT is pressed self.open_gripper_command = data[6] in [8, 10, 12] # Check if LT is pressed self.close_gripper_command = data[6] in [4, 6, 12] # Check if Y/Triangle button (bit 7) is pressed for saving # Check if X/Square button (bit 5) is pressed for failure # Check if A/Cross button (bit 4) is pressed for rerecording if buttons & 1 << 7: self.episode_end_status = "success" elif buttons & 1 << 5: self.episode_end_status = "failure" elif buttons & 1 << 4: self.episode_end_status = "rerecord_episode" else: self.episode_end_status = None except OSError as e: print(f"Error reading from gamepad: {e}") def get_deltas(self): """Get the current movement deltas from gamepad state.""" # Calculate deltas - invert as needed based on controller orientation delta_x = -self.left_y * self.x_step_size # Forward/backward delta_y = -self.left_x * self.y_step_size # Left/right delta_z = -self.right_y * self.z_step_size # Up/down return delta_x, delta_y, delta_z def should_quit(self): """Return True if quit button was pressed.""" return self.quit_requested def should_save(self): """Return True if save button was pressed.""" return self.save_requested