gym_genesis/env.py (88 lines of code) (raw):

import warnings

import gymnasium as gym
import genesis as gs
import numpy as np
from gymnasium import spaces

from gym_genesis.tasks.cube_pick import CubePick
from gym_genesis.tasks.cube_stack import CubeStack


class GenesisEnv(gym.Env):
    """Gymnasium environment that delegates to a task-specific Genesis env.

    The wrapper builds one task object (``CubePick`` or ``CubeStack``) from
    the task name, exposes that object's observation and action spaces, and
    forwards ``reset``/``step``/rendering calls to it.

    Args:
        task: Task name; ``"cube_pick"`` and ``"cube_stack"`` are supported.
        enable_pixels: Forwarded to the task. In this class it also gates
            ``render()`` and ``save_video()``.
        observation_height: Forwarded to the task constructor.
        observation_width: Forwarded to the task constructor.
        num_envs: Number of environments; forwarded to the task and used to
            size the ``truncated`` array and the reset-time ``is_success``
            list.
        env_spacing: Forwarded to the task constructor.
        render_mode: Stored on the instance; not otherwise used in this
            class.
        camera_capture_mode: ``"per_env"`` or ``"global"``; forwarded to the
            task constructor.
        strip_environment_state: Forwarded to the task constructor.

    Raises:
        NotImplementedError: If ``task`` is not a supported task name.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(
        self,
        task,
        enable_pixels=False,
        observation_height=480,
        observation_width=640,
        num_envs=1,
        env_spacing=(1.0, 1.0),
        render_mode=None,
        camera_capture_mode="per_env",  # or "global"
        strip_environment_state=True,
    ):
        super().__init__()
        self.task = task
        self.enable_pixels = enable_pixels
        self.observation_height = observation_height
        self.observation_width = observation_width
        self.num_envs = num_envs
        self.env_spacing = env_spacing
        self.render_mode = render_mode
        self.camera_capture_mode = camera_capture_mode
        self.strip_environment_state = strip_environment_state

        # The task object needs the configuration attributes above, so it is
        # built only after all of them are assigned.
        self._env = self._make_env_task(self.task)
        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space

        # === Set up Genesis scene (task-specific env will populate it) ===
        self.scene = None  # Will be created in the child class

    def reset(self, seed=None, options=None):
        """Reset the wrapped task and return ``(observation, info)``.

        Args:
            seed: Optional seed. It is passed to ``gym.Env.reset`` and, when
                not ``None``, also to the task's ``seed`` method.
            options: Accepted for Gymnasium API compatibility; unused here.

        Returns:
            A tuple ``(observation, info)`` where ``info["is_success"]`` is a
            list of ``False`` with one entry per environment.
        """
        super().reset(seed=seed)
        if seed is not None:
            self._env.seed(seed)
        observation = self._env.reset()
        info = {"is_success": [False] * self.num_envs}
        return observation, info

    def step(self, action):
        """Advance the wrapped task by one step.

        Success is defined as ``reward == 1``; a successful environment is
        reported as terminated. Truncation is never signalled.

        Args:
            action: Passed unchanged to the task's ``step``.

        Returns:
            A tuple ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is a boolean NumPy array built from the success
            mask, ``truncated`` is an all-``False`` boolean array of length
            ``num_envs``, and ``info["is_success"]`` is ``terminated``
            converted with ``tolist()``.
        """
        # The task's step returns a 4-tuple; only the second (reward) and
        # fourth (observation) elements are used here.
        _, reward, _, observation = self._env.step(action)
        is_success = reward == 1
        terminated = np.array(is_success, dtype=bool)
        truncated = np.zeros(self.num_envs, dtype=bool)  # All False
        # Convert via the normalised array rather than the raw comparison:
        # a plain Python scalar reward makes ``reward == 1`` a built-in bool,
        # which has no ``tolist()``.
        info = {"is_success": terminated.tolist()}
        return observation, reward, terminated, truncated, info

    def save_video(self, save_video: bool = False, file_name: str = "episode.mp4", fps=60):
        """Stop the task camera's recording and write it to ``file_name``.

        Nothing happens unless both ``enable_pixels`` and ``save_video`` are
        truthy; note that ``save_video`` defaults to ``False``.

        Args:
            save_video: Must be truthy for the recording to be stopped/saved.
            file_name: Passed to the camera as ``save_to_filename``.
            fps: Passed to the camera as ``fps``.
        """
        if self.enable_pixels and save_video:
            warnings.warn(
                "Calling `save_video()` will immediately stop the camera recording. "
                "You will not be able to record additional frames after this call. \n"
                "Call this method only when you are finished recording your episode.",
                stacklevel=2,
            )
            self._env.cam.stop_recording(save_to_filename=file_name, fps=fps)

    def close(self):
        """Do nothing; present for Gymnasium API compatibility."""
        pass

    def get_obs(self):
        """Return the wrapped task's current observation via its ``get_obs``."""
        return self._env.get_obs()

    def get_robot(self):
        """Return the wrapped task's ``franka`` attribute."""
        # TODO: (jadechovhari) add assertion that a robot exist
        return self._env.franka

    def render(self):
        """Return the first element of the task camera's ``render()`` result.

        Returns:
            ``self._env.cam.render()[0]`` when ``enable_pixels`` is truthy,
            otherwise ``None``.
        """
        return self._env.cam.render()[0] if self.enable_pixels else None

    def _make_env_task(self, task_name):
        """Build the task object registered under ``task_name``.

        Args:
            task_name: ``"cube_pick"`` or ``"cube_stack"``.

        Returns:
            A ``CubePick`` or ``CubeStack`` instance configured from this
            wrapper's attributes.

        Raises:
            NotImplementedError: If ``task_name`` is not a supported name.
        """
        # Built per call so the classes are resolved from module globals at
        # call time, exactly as the original if/elif chain did.
        task_classes = {
            "cube_pick": CubePick,
            "cube_stack": CubeStack,
        }
        try:
            task_cls = task_classes[task_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which the original equality
            # checks would also have rejected as unsupported.
            raise NotImplementedError(task_name) from None

        # Both task classes take the same configuration keywords.
        return task_cls(
            enable_pixels=self.enable_pixels,
            observation_height=self.observation_height,
            observation_width=self.observation_width,
            num_envs=self.num_envs,
            env_spacing=self.env_spacing,
            camera_capture_mode=self.camera_capture_mode,
            strip_environment_state=self.strip_environment_state,
        )