mujoco_worldgen/util/sim_funcs.py (85 lines of code) (raw):

import itertools
import logging

import numpy as np

logger = logging.getLogger(__name__)


# #######################################
# ############ set_action ###############
# #######################################

def ctrl_set_action(sim, action):
    """Write ``action`` into ``sim.data.ctrl``.

    For torque actuators it copies the action into mujoco ctrl field.
    For position actuators it sets the target relative to the current qpos.

    When the model has mocap bodies, the first ``7 * sim.model.nmocap``
    entries of ``action`` are discarded before anything is applied.

    Args:
        sim: simulation whose ``model``/``data`` are read; ``sim.data.ctrl``
            is modified in place. Nothing happens if ``sim.data.ctrl`` is None.
        action: array indexed by actuator (after the mocap prefix); entry
            ``i`` drives actuator ``i``.
    """
    if sim.model.nmocap > 0:
        # Drop the leading mocap portion (presumably 3 position + 4 quaternion
        # entries per mocap body -- confirm against the action producer).
        _, action = np.split(action, (sim.model.nmocap * 7,))
    if sim.data.ctrl is None:
        return
    for i in range(action.shape[0]):
        if sim.model.actuator_biastype[i] == 0:
            # biastype 0: apply the action directly (torque-style actuator).
            sim.data.ctrl[i] = action[i]
        else:
            # Otherwise the action is an offset from the current qpos of the
            # joint referenced by this actuator's transmission target.
            idx = sim.model.jnt_qposadr[sim.model.actuator_trnid[i, 0]]
            sim.data.ctrl[i] = sim.data.qpos[idx] + action[i]


# #######################################
# ############ get_reward ###############
# #######################################

def zero_get_reward(sim):
    """Reward callback that always returns ``0.0``."""
    return 0.0


def gps_dist(sim, obj0, obj1):
    """Return ``d2 + 0.3 * log(d2 + 1e-4)`` for two sites.

    ``d2`` is the squared Euclidean distance between the world positions of
    sites ``obj0`` and ``obj1``. The ``1e-4`` offset keeps the logarithm
    finite when the two sites coincide.
    """
    pos0 = sim.data.get_site_xpos(obj0)
    pos1 = sim.data.get_site_xpos(obj1)
    sq_dist = np.sum(np.square(pos0 - pos1))
    return sq_dist + 0.3 * np.log(sq_dist + 1e-4)


def l2_dist(sim, obj0, obj1):
    """Return the root-mean-square difference between two site positions.

    NOTE: despite its name this computes ``sqrt(mean((p0 - p1) ** 2))``,
    i.e. the Euclidean (L2) distance divided by ``sqrt(len(p0))``. It is kept
    unchanged so that existing reward scales do not shift.
    """
    pos0 = sim.data.get_site_xpos(obj0)
    pos1 = sim.data.get_site_xpos(obj1)
    return np.sqrt(np.mean(np.square(pos0 - pos1)))


# #######################################
# ########### get_diverged ##############
# #######################################

def false_get_diverged(sim):
    """Divergence callback that never reports divergence: ``(False, 0.0)``."""
    return False, 0.0


def _out_of_bounds(values, limit):
    """Return True if any ``|value| > limit`` or any value is NaN/inf."""
    # ``<=`` is False for NaN (and for inf against a finite limit), so a single
    # non-finite entry makes ``np.all`` False. ``np.max`` would instead
    # propagate NaN, and ``nan > limit`` is False, hiding the divergence.
    return not np.all(np.abs(values) <= limit)


def simple_get_diverged(sim):
    """Report whether the simulation state has blown up.

    Returns:
        ``(True, -20.0)`` when any ``|qpos|`` exceeds 1000, any ``|qvel|``
        exceeds 100, or either contains NaN/inf; otherwise ``(False, 0.0)``.
        A ``None`` qpos is treated as "not diverged". The second element is
        presumably a reward penalty -- confirm with the caller.
    """
    if sim.data.qpos is None:
        return False, 0.0
    if (_out_of_bounds(sim.data.qpos, 1000.0)
            or _out_of_bounds(sim.data.qvel, 100.0)):
        return True, -20.0
    return False, 0.0


# #######################################
# ############# get_info ################
# #######################################

def empty_get_info(sim):
    """Info callback that returns an empty dict."""
    return {}


# #######################################
# ############## get_obs ################
# #######################################

def flatten_get_obs(sim):
    """Return ``qpos`` followed by ``qvel`` as one flat array.

    Returns an empty array when ``sim.data.qpos`` is None.
    """
    if sim.data.qpos is None:
        return np.zeros(0)
    return np.concatenate([sim.data.qpos, sim.data.qvel])


def image_get_obs(sim):
    """Return a 100x100 render from the camera named ``"rgb"``."""
    return sim.render(100, 100, camera_name="rgb")


# #######################################
# ############## Helpers ################
# #######################################

def get_body_geom_ids(model, body_name):
    """Return the ids of all geoms whose body is ``body_name``, ascending."""
    body_id = model.body_name2id(body_name)
    return [geom_id for geom_id in range(model.ngeom)
            if model.geom_bodyid[geom_id] == body_id]


def change_geom_alpha(model, body_name_prefix, new_alpha):
    """Change the visual transparency (alpha) of matching bodies' geoms.

    For every body whose name starts with ``body_name_prefix``, sets the
    alpha channel (column 3 of ``model.geom_rgba``) of each of its geoms to
    ``new_alpha``. ``model`` is modified in place.
    """
    for body_name in model.body_names:
        if not body_name.startswith(body_name_prefix):
            continue
        for geom_id in get_body_geom_ids(model, body_name):
            model.geom_rgba[geom_id, 3] = new_alpha


def _addr_to_idxs(addr):
    """Expand a joint address (an int or a ``(start, end)`` tuple) to a list."""
    if isinstance(addr, tuple):
        # Multi-value joints report a half-open [start, end) range.
        return list(range(addr[0], addr[1]))
    return [addr]


def _joint_names_with_prefix(sim, prefix):
    """Return the model's joint names starting with ``prefix``, in model order."""
    return [name for name in sim.model.joint_names if name.startswith(prefix)]


def joint_qpos_idxs(sim, joint_name):
    """Return the indexes of the specified joint's qpos values as a list."""
    return _addr_to_idxs(sim.model.get_joint_qpos_addr(joint_name))


def qpos_idxs_from_joint_prefix(sim, prefix):
    """Return the qpos indexes of all joints matching ``prefix``, concatenated."""
    return list(itertools.chain.from_iterable(
        joint_qpos_idxs(sim, name)
        for name in _joint_names_with_prefix(sim, prefix)))


def joint_qvel_idxs(sim, joint_name):
    """Return the indexes of the specified joint's qvel values as a list."""
    return _addr_to_idxs(sim.model.get_joint_qvel_addr(joint_name))


def qvel_idxs_from_joint_prefix(sim, prefix):
    """Return the qvel indexes of all joints matching ``prefix``, concatenated."""
    return list(itertools.chain.from_iterable(
        joint_qvel_idxs(sim, name)
        for name in _joint_names_with_prefix(sim, prefix)))


def body_names_from_joint_prefix(sim, prefix):
    """Return the body name of every joint matching ``prefix``.

    There is one entry per matching joint, in joint order. A body that owns
    several matching joints therefore appears several times.
    """
    model = sim.model
    return [model.body_id2name(model.jnt_bodyid[model.joint_name2id(name)])
            for name in _joint_names_with_prefix(sim, prefix)]