mujoco_worldgen/builder.py (85 lines of code) (raw):

import logging
from collections import OrderedDict
from copy import deepcopy as copy

import numpy as np
from mujoco_py import const, load_model_from_xml, MjSim

from mujoco_worldgen.util.path import worldgen_path
from mujoco_worldgen.objs.obj import Obj
from mujoco_worldgen.parser import unparse_dict, update_mujoco_dict

logger = logging.getLogger(__name__)


class WorldBuilder(Obj):
    '''
    Root object of a worldgen scene.

    Holds a private copy of the world parameters and a seeded random state,
    collects top-level children via append(), and turns the whole object
    tree into a MuJoCo simulation via get_sim().
    '''
    classname = 'worldbuilder'  # Used for __repr__

    def __init__(self, world_params, seed):
        '''
        Args:
            world_params: world configuration. This class reads its ``size``
                and ``num_substeps`` attributes. It is deep-copied so later
                changes made by the caller do not leak into this builder.
            seed: seed for the builder's ``np.random.RandomState``.
        '''
        self.world_params = copy(world_params)
        self.random_state = np.random.RandomState(seed)
        super(WorldBuilder, self).__init__()
        # Normally size is set during generate() but we are the top level world.
        # Read it from our private copy so it does not alias the caller's object.
        self.size = self.world_params.size
        # Normally relative_position is set by our parent but we are the root.
        self.relative_position = (0, 0)

    def append(self, obj):
        '''
        Attach ``obj`` to this builder's "top" placement.

        Returns:
            self, so calls can be chained.
        '''
        super(WorldBuilder, self).append(obj, "top")
        return self

    def generate_xml_dict(self):
        '''
        Get the mujoco header XML dict. It contains the compiler and option
        nodes.

        The compiler node selects radians and local coordinates and points
        the mesh/texture directories at the worldgen assets. The option node
        enables warmstart.
        '''
        compiler = OrderedDict()
        compiler['@angle'] = 'radian'
        compiler['@coordinate'] = 'local'
        compiler['@meshdir'] = worldgen_path('assets/stls')
        compiler['@texturedir'] = worldgen_path('assets/textures')
        option = OrderedDict()
        option["flag"] = OrderedDict([("@warmstart", "enable")])
        return OrderedDict([('compiler', compiler), ('option', option)])

    def generate_xinit(self):
        return {}  # Builder has no xinit

    def to_xml_dict(self):
        '''
        Generate the XML dict for this object and all of its children.

        Starts from the header produced by generate_xml_dict() and merges in
        every child's to_xml_dict() with update_mujoco_dict(). It then calls
        each of this builder's transforms on the merged dict. Their return
        values are ignored, so transforms are expected to mutate it in place.

        Returns:
            The merged xml_dict.
        '''
        xml_dict = self.generate_xml_dict()
        assert len(self.markers) == 0, "Can't mark builder object."
        # Then add the xml of all of our children
        for children in self.children.values():
            for child, _ in children:
                child_dict = child.to_xml_dict()
                update_mujoco_dict(xml_dict, child_dict)
        for transform in self.transforms:
            transform(xml_dict)
        return xml_dict

    def get_sim(self):
        '''
        Compile the world and build a ``mujoco_py.MjSim`` for it.

        The method proceeds in order:
          1. Assign names and run compile() with the builder's random state.
          2. Resolve absolute positions and render the XML.
          3. Load the model and apply initial joint positions from to_xinit().
          4. Reset the data of weld equality constraints.
          5. Install a merged udd_callback if any object provides one.

        Returns:
            MjSim created with ``nsubsteps=world_params.num_substeps``.

        Raises:
            FullVirtualWorldException: if compile() returns a falsy result.
        '''
        self.placements = OrderedDict()
        self.placements["top"] = {"origin": np.zeros(3),
                                  "size": self.world_params.size}
        name_indexes = OrderedDict()
        self.to_names(name_indexes)
        res = self.compile(self.random_state, world_params=self.world_params)
        if not res:
            raise FullVirtualWorldException('Failed to compile world')
        self.set_absolute_position((0, 0, 0))  # Recursively set all positions
        xml_dict = self.to_xml_dict()
        xinit_dict = self.to_xinit()
        udd_callbacks = self.to_udd_callback()

        xml = unparse_dict(xml_dict)
        model = load_model_from_xml(xml)
        sim = MjSim(model, nsubsteps=self.world_params.num_substeps)

        for name, value in xinit_dict.items():
            sim.data.set_joint_qpos(name, value)

        # Places mocap where related bodies are.
        # NOTE(review): presumably each weld row is a relative pose
        # (pos xyz + w-first quaternion), so this is zero offset + identity
        # rotation. Confirm against the MuJoCo version in use.
        if sim.model.nmocap > 0 and sim.model.eq_data is not None:
            for i in range(sim.model.eq_data.shape[0]):
                if sim.model.eq_type[i] == const.EQ_WELD:
                    sim.model.eq_data[i, :] = np.array(
                        [0., 0., 0., 1., 0., 0., 0.])

        # to_udd_callback() may return None. No callbacks means nothing is
        # installed on the sim.
        udd_callbacks = udd_callbacks or []
        if udd_callbacks:
            def merged_udd_callback(sim):
                # Later callbacks overwrite keys produced by earlier ones.
                ret = {}
                for udd_callback in udd_callbacks:
                    ret.update(udd_callback(sim))
                return ret
            sim.udd_callback = merged_udd_callback
        return sim


class FullVirtualWorldException(Exception):
    '''
    Raised by WorldBuilder.get_sim() when compiling the world fails.

    Args:
        msg: extra detail appended to the standard "world is full" message.
    '''
    def __init__(self, msg=''):
        super(FullVirtualWorldException, self).__init__(
            "Virtual world is full of objects. " +
            "Cannot allocate more of them. " + msg)