mujoco_py/mjrenderpool.py (148 lines of code) (raw):

import ctypes
import inspect
from multiprocessing import Array, get_start_method, Pool, Value

import numpy as np


class RenderPoolStorage:
    """
    Helper object used for storing global data for worker processes.

    One instance is created per worker process by
    MjRenderPool._worker_init and published as the module-level
    `_render_pool_storage` global, where MjRenderPool._worker_render
    picks it up again.
    """

    __slots__ = ['shared_rgbs_array',
                 'shared_depths_array',
                 'device_id',
                 'sim',
                 'modder']


class MjRenderPool:
    """
    Utilizes a process pool to render a MuJoCo simulation across
    multiple GPU devices. This can scale the throughput linearly
    with the number of available GPUs. Throughput can also be
    slightly increased by using more than one worker per GPU.
    """

    DEFAULT_MAX_IMAGE_SIZE = 512 * 512  # in pixels

    def __init__(self, model, device_ids=1, n_workers=None,
                 max_batch_size=None, max_image_size=DEFAULT_MAX_IMAGE_SIZE,
                 modder=None):
        """
        Args:
        - model (PyMjModel): MuJoCo model to use for rendering
        - device_ids (int/list): list of device ids to use for rendering.
            An int N is expanded to the ids 0..N-1. One or more workers
            will be assigned to each device, depending on how many
            workers are requested.
        - n_workers (int): number of worker processes per device; the
            pool holds len(device_ids) * n_workers processes in total.
            Defaults to 1.
        - max_batch_size (int): maximum number of states that can be
            rendered in batch using .render(). Defaults to the total
            number of worker processes.
        - max_image_size (int): maximum number pixels in images requested
            by .render()
        - modder (Modder): modder class (not an instance) to use for
            domain randomization. It is instantiated once per worker.

        Raises:
        - ValueError: if modder is not a class, or if no device is
            selected.
        - RuntimeError: if the multiprocessing start method is not
            'spawn'.
        """
        # Set these first so that close() / __del__ work even if the
        # constructor fails further down.
        self._closed, self.pool = False, None

        if not (modder is None or inspect.isclass(modder)):
            raise ValueError("modder must be a class")

        if isinstance(device_ids, int):
            device_ids = list(range(device_ids))
        else:
            assert isinstance(device_ids, list), (
                "device_ids must be list of integer")

        # An empty device list would size the pool at zero processes and
        # only fail later inside Pool(); report it here instead.
        if not device_ids:
            raise ValueError("device_ids must select at least one device")

        n_workers = n_workers or 1
        n_processes = len(device_ids) * n_workers
        self._max_batch_size = max_batch_size or n_processes
        self._max_image_size = max_image_size

        # Shared output buffers, sized for the largest batch of the
        # largest image: three uint8 channels per pixel for RGB and one
        # float per pixel for depth.
        array_size = self._max_image_size * self._max_batch_size

        self._shared_rgbs = Array(ctypes.c_uint8, array_size * 3)
        self._shared_depths = Array(ctypes.c_float, array_size)

        self._shared_rgbs_array = np.frombuffer(
            self._shared_rgbs.get_obj(), dtype=ctypes.c_uint8)
        assert self._shared_rgbs_array.size == (array_size * 3), (
            "Array size is %d, expected %d" % (
                self._shared_rgbs_array.size, array_size * 3))
        self._shared_depths_array = np.frombuffer(
            self._shared_depths.get_obj(), dtype=ctypes.c_float)
        assert self._shared_depths_array.size == array_size, (
            "Array size is %d, expected %d" % (
                self._shared_depths_array.size, array_size))

        # Shared counter from which every worker draws a unique index
        # during _worker_init.
        worker_id = Value(ctypes.c_int)
        worker_id.value = 0

        if get_start_method() != "spawn":
            raise RuntimeError(
                "Start method must be set to 'spawn' for the "
                "render pool to work. That is, you must add the "
                "following to the _TOP_ of your main script, "
                "before any other imports (since they might be "
                "setting it otherwise):\n"
                " import multiprocessing as mp\n"
                " if __name__ == '__main__':\n"
                " mp.set_start_method('spawn')\n")

        self.pool = Pool(
            processes=n_processes,
            initializer=MjRenderPool._worker_init,
            initargs=(
                model.get_mjb(),
                worker_id,
                device_ids,
                self._shared_rgbs,
                self._shared_depths,
                modder))

    @staticmethod
    def _worker_init(mjb_bytes, worker_id, device_ids,
                     shared_rgbs, shared_depths, modder):
        """
        Initializes the global state for the workers.

        Args:
        - mjb_bytes: serialized model, as returned by model.get_mjb()
        - worker_id (Value): shared counter used to hand out one index
            per worker process
        - device_ids (list): device ids; workers are assigned to them
            round-robin by worker index
        - shared_rgbs (Array): shared buffer that receives RGB output
        - shared_depths (Array): shared buffer that receives depth output
        - modder: modder class to instantiate for this worker, or None
        """
        s = RenderPoolStorage()

        # Claim a unique worker index under the counter's lock.
        with worker_id.get_lock():
            proc_worker_id = worker_id.value
            worker_id.value += 1
        s.device_id = device_ids[proc_worker_id % len(device_ids)]

        s.shared_rgbs_array = np.frombuffer(
            shared_rgbs.get_obj(), dtype=ctypes.c_uint8)
        s.shared_depths_array = np.frombuffer(
            shared_depths.get_obj(), dtype=ctypes.c_float)

        # avoid a circular import
        from mujoco_py import load_model_from_mjb, MjRenderContext, MjSim
        s.sim = MjSim(load_model_from_mjb(mjb_bytes))
        # attach a render context to the sim (needs to happen before
        # modder is called, since it might need to upload textures
        # to the GPU).
        MjRenderContext(s.sim, device_id=s.device_id)

        if modder is not None:
            # The worker index doubles as the modder's random_state.
            s.modder = modder(s.sim, random_state=proc_worker_id)
            s.modder.whiten_materials()
        else:
            s.modder = None

        global _render_pool_storage
        _render_pool_storage = s

    @staticmethod
    def _worker_render(worker_id, state, width, height,
                       camera_name, randomize):
        """
        Main target function for the workers.

        Renders one image and writes it into the shared buffers.

        Args:
        - worker_id (int): index of this image within the batch; it
            selects the slot of the shared buffers that is written.
        - state: simulation state to set before rendering, or None to
            keep the worker's current state
        - width (int): width of image to render.
        - height (int): height of image to render.
        - camera_name (str): name of camera to render from.
        - randomize (bool): if True and the worker has a modder, call
            modder.randomize() before rendering.
        """
        s = _render_pool_storage

        # Only run sim.forward() when the state or the modder actually
        # changed something.
        forward = False
        if state is not None:
            s.sim.set_state(state)
            forward = True
        if randomize and s.modder is not None:
            s.modder.randomize()
            forward = True
        if forward:
            s.sim.forward()

        # Slots are packed back to back using the requested image size,
        # which is the layout render() reads the batch back with.
        rgb_block = width * height * 3
        rgb_offset = rgb_block * worker_id
        rgb = s.shared_rgbs_array[rgb_offset:rgb_offset + rgb_block]
        rgb = rgb.reshape(height, width, 3)

        depth_block = width * height
        depth_offset = depth_block * worker_id
        depth = s.shared_depths_array[depth_offset:depth_offset + depth_block]
        depth = depth.reshape(height, width)

        rgb[:], depth[:] = s.sim.render(
            width, height, camera_name=camera_name, depth=True,
            device_id=s.device_id)

    def render(self, width, height, states=None, camera_name=None,
               depth=False, randomize=False, copy=True):
        """
        Renders the simulations in batch. If no states are provided,
        the max_batch_size will be used.

        Args:
        - width (int): width of image to render.
        - height (int): height of image to render.
        - states (list): list of MjSimStates; updates the states before
            rendering. Batch size will be number of states supplied.
        - camera_name (str): name of camera to render from.
        - depth (bool): if True, also return depth.
        - randomize (bool): calls modder.randomize() before rendering
            (no effect if the pool was created without a modder).
        - copy (bool): return a copy rather than a reference. With
            copy=False the returned arrays are views into the pool's
            shared buffers, which the next call to render() writes to.

        Returns:
        - rgbs: NxHxWx3 numpy array of N images in batch of width W
            and height H.
        - depth: NxHxW numpy array of N images in batch of width W
            and height H. Only returned if depth=True.

        Raises:
        - RuntimeError: if the pool has been closed.
        - ValueError: if the image or the batch exceeds the maximum
            size the pool was created with.
        """
        if self._closed:
            raise RuntimeError("The pool has been closed.")

        if (width * height) > self._max_image_size:
            raise ValueError(
                "Requested image larger than maximum image size. Create "
                "a new RenderPool with a larger maximum image size.")

        if states is None:
            batch_size = self._max_batch_size
            states = [None] * batch_size
        else:
            batch_size = len(states)

        if batch_size > self._max_batch_size:
            raise ValueError(
                "Requested batch size larger than max batch size. Create "
                "a new RenderPool with a larger max batch size.")

        # Each task writes its image into slot i of the shared buffers.
        self.pool.starmap(
            MjRenderPool._worker_render,
            [(i, state, width, height, camera_name, randomize)
             for i, state in enumerate(states)])

        rgbs = self._shared_rgbs_array[:width * height * 3 * batch_size]
        rgbs = rgbs.reshape(batch_size, height, width, 3)
        if copy:
            rgbs = rgbs.copy()

        if depth:
            depths = self._shared_depths_array[:width * height * batch_size]
            # Reshape only: copying is decided by `copy`, exactly as for
            # rgbs, so copy=False really returns a view.
            depths = depths.reshape(batch_size, height, width)
            if copy:
                depths = depths.copy()
            return rgbs, depths
        else:
            return rgbs

    def close(self):
        """
        Closes the pool and terminates child processes.

        Safe to call more than once, and on an instance whose pool was
        never created.
        """
        if not self._closed:
            if self.pool is not None:
                self.pool.close()
                self.pool.join()
            self._closed = True

    def __del__(self):
        # Best-effort cleanup when the pool object is garbage collected.
        self.close()