in mujoco_py/mjrenderpool.py [0:0]
def __init__(self, model, device_ids=1, n_workers=None,
             max_batch_size=None, max_image_size=DEFAULT_MAX_IMAGE_SIZE,
             modder=None):
    """
    Args:
    - model (PyMjModel): MuJoCo model to use for rendering. It is
        serialized with ``model.get_mjb()`` and passed to the pool's
        worker initializer.
    - device_ids (int/list): list of device ids to use for rendering.
        An int ``N`` is shorthand for ``list(range(N))``. One or more
        workers will be assigned to each device, depending on
        ``n_workers``.
    - n_workers (int): number of worker processes *per device*; the pool
        has ``len(device_ids) * n_workers`` processes in total. Defaults
        to 1 (one process per device id).
    - max_batch_size (int): maximum number of states that can be rendered
        in batch using .render(). Defaults to the total number of worker
        processes.
    - max_image_size (int): maximum number of pixels in images requested
        by .render()
    - modder (Modder): modder *class* (not an instance) to use for domain
        randomization.

    Raises:
    - ValueError: if ``modder`` is not a class, or if the configuration
        yields no worker processes (e.g. empty ``device_ids``).
    - AssertionError: if ``device_ids`` is neither an int nor a list.
    - RuntimeError: if the multiprocessing start method is not "spawn".
    """
    # Set before any validation so the object is in a well-defined state
    # even if __init__ raises below (presumably relied upon by cleanup code
    # elsewhere in the class -- confirm).
    self._closed, self.pool = False, None

    if not (modder is None or inspect.isclass(modder)):
        raise ValueError("modder must be a class")

    if isinstance(device_ids, int):
        device_ids = list(range(device_ids))
    else:
        assert isinstance(device_ids, list), (
            "device_ids must be list of integer")

    # n_workers is a per-device multiplier; 0/None both mean "one per device".
    n_workers = n_workers or 1
    n_processes = len(device_ids) * n_workers
    if n_processes < 1:
        # Pool() would reject this as well, but only after the shared
        # buffers below had been allocated; fail early and explain why.
        raise ValueError(
            "render pool needs at least one worker process, got "
            "%d device id(s) x %d worker(s) per device" % (
                len(device_ids), n_workers))

    # Check the start method before allocating the (potentially large)
    # shared buffers so that a misconfigured script fails fast.
    if get_start_method() != "spawn":
        raise RuntimeError(
            "Start method must be set to 'spawn' for the "
            "render pool to work. That is, you must add the "
            "following to the _TOP_ of your main script, "
            "before any other imports (since they might be "
            "setting it otherwise):\n"
            "  import multiprocessing as mp\n"
            "  if __name__ == '__main__':\n"
            "      mp.set_start_method('spawn')\n")

    self._max_batch_size = max_batch_size or n_processes
    self._max_image_size = max_image_size

    # Room for max_image_size pixels per batch entry: 3 uint8 values per
    # pixel for RGB, and one float per pixel for depth.
    array_size = self._max_image_size * self._max_batch_size
    self._shared_rgbs = Array(ctypes.c_uint8, array_size * 3)
    self._shared_depths = Array(ctypes.c_float, array_size)

    # numpy views over the shared buffers (np.frombuffer does not copy).
    self._shared_rgbs_array = np.frombuffer(
        self._shared_rgbs.get_obj(), dtype=ctypes.c_uint8)
    assert self._shared_rgbs_array.size == (array_size * 3), (
        "Array size is %d, expected %d" % (
            self._shared_rgbs_array.size, array_size * 3))
    self._shared_depths_array = np.frombuffer(
        self._shared_depths.get_obj(), dtype=ctypes.c_float)
    assert self._shared_depths_array.size == array_size, (
        "Array size is %d, expected %d" % (
            self._shared_depths_array.size, array_size))

    # Shared counter handed to every worker's initializer (presumably used
    # to give each worker a distinct index -- confirm in _worker_init).
    worker_id = Value(ctypes.c_int)
    worker_id.value = 0

    self.pool = Pool(
        processes=n_processes,
        initializer=MjRenderPool._worker_init,
        initargs=(
            model.get_mjb(),
            worker_id,
            device_ids,
            self._shared_rgbs,
            self._shared_depths,
            modder))