scripts/gen_wrappers.py (773 lines of code) (raw):

#!/usr/bin/env python
"""Generate mujoco-py's Cython wrappers and Python constants from MuJoCo headers.

Reads the MuJoCo C headers (from ``~/.mujoco/mujoco210/include``) and writes

* a ``.pxi`` file (``argv[1]`` or ``mujoco_py/generated/wrappers.pxi``) with a
  ``cdef class PyMj<Name>`` wrapper plus ``WrapMj<Name>`` factory per struct,
  NumPy view helpers for fixed-size arrays and shaped pointers, and
  ``_<func>`` wrappers for the supported functions declared in ``mujoco.h``;
* ``mujoco_py/generated/const.py`` with the numeric ``#define`` and enum
  constants (first two characters, the ``mj`` prefix, stripped).

Pointer fields are only wrapped when the header documents their shape with a
``// ... (rows x cols)`` comment.
"""
import codecs
import keyword
import os
import re
import subprocess
import sys
import tempfile
import traceback
from collections import OrderedDict

import pycparser
from pycparser.c_ast import ArrayDecl, TypeDecl, PtrDecl, Union, BinaryOp, UnaryOp

# Compiler used only as a C preprocessor (-E -P).
c_compiler = "cc"
if sys.platform.startswith("win"):
    c_compiler = "cl"

# Disabling a few accessors that are unsafe due to ambiguous meaning: pointer
# fields with these names get a property that raises and names the replacement.
REPLACEMENT_BY_ORIGINAL = OrderedDict([
    ('xpos', 'body_xpos'),
    ('xmat', 'body_xmat'),
    ('xquat', 'body_xquat'),
    ('efc_pos', 'active_contacts_efc_pos'),
])

# Scalar field types exposed as plain read/write properties.
_SCALAR_TYPES = ('float', 'int', 'mjtNum', 'mjtByte', 'unsigned int')

# Structs whose wrappers allocate their own storage in __cinit__.
_SELF_ALLOCATED_STRUCTS = ('mjVFS', 'mjrRect')

# These structs don't have a corresponding typedef, so no wrapper is emitted.
_STRUCTS_WITHOUT_TYPEDEF = ('mjuiItemSingle', 'mjuiItemMulti', 'mjuiItemSlider', 'mjuiItemEdit')

# (count suffix, name-address suffix, Python name, mjtObj suffix) for the object
# types whose names are indexed on PyMjModel. MuJoCo isn't very consistent in
# how it uses long and abbreviated names :(
_NAMED_OBJECT_FIELDS = (
    ('body', 'body', 'body', 'BODY'),
    ('jnt', 'jnt', 'joint', 'JOINT'),
    ('geom', 'geom', 'geom', 'GEOM'),
    ('site', 'site', 'site', 'SITE'),
    ('light', 'light', 'light', 'LIGHT'),
    ('cam', 'cam', 'camera', 'CAMERA'),
    ('u', 'actuator', 'actuator', 'ACTUATOR'),
    ('sensor', 'sensor', 'sensor', 'SENSOR'),
    ('tendon', 'tendon', 'tendon', 'TENDON'),
    ('mesh', 'mesh', 'mesh', 'MESH'),
)


def tryint(x):
    """Return ``int(x)`` if *x* parses as an integer, otherwise *x* unchanged.

    Used on the dimensions of ``(rows x cols)`` shape hints, which are either
    integer literals or symbolic sizes kept as strings.
    """
    try:
        return int(x)
    except (TypeError, ValueError):
        return x


def get_struct_dict(struct, struct_name, array_shapes):
    """Describe the fields of one C struct for wrapper generation.

    :param struct: pycparser ``Struct`` node whose children are field ``Decl`` nodes.
    :param struct_name: Name of this struct in the result (anonymous nested
        structs are named ``<parent>_<field>``).
    :param array_shapes: Mapping ``"<struct>.<field>" -> (shape0, shape1)`` from
        :func:`get_array_shapes`; pointer fields without an entry are skipped.
    :return: OrderedDict mapping ``struct_name`` (first) and then any nested
        anonymous structs to a field description with the keys ``scalars``,
        ``arrays``, ``arrays2d``, ``ptrs`` and ``depends_on_model``.
    :raises NotImplementedError: for a field declarator kind not handled here.
    """
    fields = OrderedDict([('scalars', []), ('arrays', []), ('arrays2d', []),
                          ('ptrs', []), ('depends_on_model', False)])
    struct_dict = OrderedDict()
    struct_dict[struct_name] = fields
    for child in struct.children():
        child_name = child[1].name
        child_type = child[1].type
        decl = child[1].children()[0][1]
        if isinstance(child_type, ArrayDecl):
            if hasattr(decl.type.type, "names"):
                array_type = ' '.join(decl.type.type.names)
                array_size = extract_size_info(decl.dim)
                fields['arrays'].append((child_name, array_type, array_size))
            elif isinstance(decl.type, PtrDecl):
                print("skipping pointer array: %s.%s" % (struct_name, child_name))
            elif hasattr(decl.type.type.type, "names"):
                # assuming a 2d array: ArrayDecl(ArrayDecl(TypeDecl(IdentifierType)))
                array_type = ' '.join(decl.type.type.type.names)
                s1 = extract_size_info(decl.dim)
                s2 = extract_size_info(decl.type.dim)
                fields['arrays2d'].append((child_name, array_type, (s1, s2)))
            else:
                print("skipping unknown array case: %s.%s\n%s" % (struct_name, child_name, child))
        elif isinstance(child_type, TypeDecl):
            if isinstance(decl.type, pycparser.c_ast.Struct):
                # Anonymous nested struct: wrap it as its own struct.
                fixed_name = decl.declname
                if fixed_name == "global":
                    # "global" is a Python keyword.
                    fixed_name = "global_"
                name = struct_name + "_" + fixed_name
                struct_dict.update(get_struct_dict(decl.type, name, array_shapes))
                fields['scalars'].append((fixed_name, name))
            else:
                field_type = ' '.join(decl.type.names)
                fields['scalars'].append((child_name, field_type))
        elif isinstance(child_type, PtrDecl):
            ptr_type = ' '.join(decl.type.type.names)
            n = struct_name + '.' + child_name
            if n not in array_shapes:
                print('Warning: skipping {} due to unknown shape'.format(n))
            else:
                fields['ptrs'].append((child_name, ptr_type, array_shapes[n]))
                # Structs needing array shapes must get them through mjModel
                # but mjModel itself doesn't need to be passed an extra mjModel.
                # TODO: depends_on_model should be set to True if any member of this struct
                # depends on mjModel but currently that never happens.
                if struct_name != 'mjModel':
                    fields['depends_on_model'] = True
        elif isinstance(child_type, Union):
            # Unions are ignored until we think they're necessary.
            continue
        else:
            raise NotImplementedError

    assert isinstance(struct_dict, OrderedDict), 'Must be deterministic'
    return struct_dict


def extract_size_info(node):
    """Evaluate an array-dimension AST node to an int or a C expression string.

    Handles ``pycparser.c_ast.Constant`` (-> int), ``pycparser.c_ast.ID``
    (-> the name) and ``BinaryOp(op="*")`` of those. A product of two ints is
    evaluated; a product involving a name is returned as the C expression
    ``"<left>*<right>"`` (it is pasted into generated Cython code).

    :param node: The AST node of an array dimension.
    :return: The integer size, or a string expression.
    :raises NotImplementedError: for any other node kind or operator.
    """
    if isinstance(node, pycparser.c_ast.ID):
        return node.name
    if isinstance(node, pycparser.c_ast.Constant):
        return int(node.value)
    if isinstance(node, BinaryOp) and node.op == "*":
        left = extract_size_info(node.left)
        right = extract_size_info(node.right)
        if isinstance(left, int) and isinstance(right, int):
            return left * right
        # Symbolic factor: build the expression instead of Python's string
        # repetition ("NAME" * 2 == "NAMENAME").
        return "{}*{}".format(left, right)
    raise NotImplementedError(str(node))


def format_size_argument(model_var_name, shape_def):
    """Render one dimension of a pointer shape hint as a Cython expression.

    ``"nbody"`` -> ``"<model_var_name>.nbody"`` and ``"3*nbody"`` ->
    ``"3*<model_var_name>.nbody"``; integers and other strings are returned
    unchanged.
    """
    if isinstance(shape_def, str):
        if shape_def.startswith('n'):
            return f'{model_var_name}.{shape_def}'
        # \w (not [a-zA-Z]) so names such as nuser_body are not truncated;
        # fullmatch so trailing text is never silently dropped.
        m = re.fullmatch(r'\s*(\d+)\s*\*\s*(n\w*)\s*', shape_def)
        if m:
            return f'{m.group(1)}*{model_var_name}.{m.group(2)}'
    return shape_def


def get_full_scr_lines(HEADER_DIR, HEADER_FILES):
    """Read the header files in order and return all their lines, stripped.

    (The name's "scr" is historical; it means "src".)
    """
    file_contents = []
    for filename in HEADER_FILES:
        # mujoco 2.0 header files fail when parsed as utf-8
        with codecs.open(os.path.join(HEADER_DIR, filename), 'r', encoding='latin-1') as f:
            file_contents.append(f.read())
    return [line.strip() for line in '\n'.join(file_contents).splitlines()]


def get_array_shapes(full_src_lines):
    """Collect pointer shape hints of the form ``T* name; // ... (rows x cols)``.

    :return: OrderedDict ``"<struct>.<field>" -> (rows, cols)``, where the
        struct name is the one from the closest preceding ``struct _X`` line
        with its leading character dropped; dims go through :func:`tryint`.
    """
    array_shapes = OrderedDict()
    curr_struct_name = None
    for line in full_src_lines:
        # Current struct name
        m = re.match(r'struct (\w+)', line)
        if m:
            curr_struct_name = m.group(1)
            continue
        # Pointer with a shape comment
        m = re.match(r'\s*\w+\s*\*\s+(\w+);\s*//.*\((.+) x (.+)\)$', line)
        if m:
            name = curr_struct_name[1:] + '.' + m.group(1)
            assert name not in array_shapes
            array_shapes[name] = (tryint(m.group(2)), tryint(m.group(3)))
    return array_shapes


def get_processed_src(HEADER_DIR, full_src_lines):
    """Run the C preprocessor over the concatenated headers and return its output.

    The concatenated source is written to a temp file that is deliberately kept
    (``delete=False``) and its path printed, so it can be inspected.
    """
    with tempfile.NamedTemporaryFile(suffix='.h', delete=False) as f:
        f.write('\n'.join(full_src_lines).encode())
        f.flush()
        print("Saved all header information to: %s" % f.name)
        # -E: run preprocessor only
        # -P: don't generate debug lines starting with #
        # -I: include directory
        processed_src = subprocess.check_output(
            [c_compiler, '-E', '-P', '-I', HEADER_DIR, f.name]).decode()
    return processed_src


def get_full_struct_dict(processed_src, array_shapes):
    """Parse the preprocessed headers and describe every ``struct _mj...``.

    :return: OrderedDict struct name (leading underscore removed) -> field
        description, see :func:`get_struct_dict`.
    """
    ast = pycparser.c_parser.CParser().parse(processed_src)
    struct_dict = OrderedDict()
    for node in ast.children():
        decl = node[1]
        assert (decl.name is None) == isinstance(decl.type, pycparser.c_ast.Struct)
        if isinstance(decl.type, pycparser.c_ast.Struct):
            (_, struct), = decl.children()
            assert struct.name.startswith('_mj')
            struct_name = struct.name[1:]  # take out leading underscore
            assert struct_name not in struct_dict
            # Update in place; copying the dict per struct was quadratic.
            struct_dict.update(get_struct_dict(struct, struct_name, array_shapes))

    assert isinstance(struct_dict, OrderedDict), 'Must be deterministic'
    return struct_dict


def _enum_value(value):
    """Evaluate the explicit value expression of an enumerator.

    Supports ``a << b`` and ``-a`` on integer constants and plain constants;
    for any other node with children, the second child's ``value`` is used
    (presumably the operand of a cast — verify if new headers add such forms).

    :raises NotImplementedError: for other binary/unary operators.
    """
    if isinstance(value, BinaryOp):
        if value.op == '<<':
            # Simple constant shift; throws if it's anything more complex.
            return int(value.left.value) << int(value.right.value)
        raise NotImplementedError(value.op)
    if isinstance(value, UnaryOp):
        if value.op == '-':
            return -int(value.expr.value)
        raise NotImplementedError(value.op)
    children = value.children()
    if children:
        return int(children[1][1].value)
    return int(value.value)


def get_const_from_enum(processed_src):
    """Return ``const.py`` lines (``NAME = value``) for every enum typedef.

    Each enum is introduced by a comment line with its name and followed by an
    empty line; enumerators without an explicit value continue from the
    previous one.
    """
    ast = pycparser.c_parser.CParser().parse(processed_src)
    lines = []
    for node in ast.children():
        assert (node[1].name is None) == isinstance(
            node[1].type, pycparser.c_ast.Struct)
        type_decl = node[1].children()[0][1]
        # Only typedefs of enums are of interest here.
        if not (hasattr(type_decl, "type") and isinstance(type_decl.type, pycparser.c_ast.Enum)):
            continue
        lines.append(" # " + type_decl.type.name)
        # The enumerator list is a list of (name, value) enumerations.
        enumlist = type_decl.children()[0][1].children()[0][1].children()
        last_value = None
        for _, enum in enumlist:
            var = enum.name[2:]
            if enum.value is not None:
                last_value = _enum_value(enum.value)
            else:
                assert last_value is not None
                last_value += 1
            lines.append(str(var) + " = " + str(last_value))
        lines.append("")
    return lines


def get_struct_wrapper(struct_dict):
    """Map each struct name ``mjX`` to its wrapper class ``PyMjX`` and factory ``WrapMjX``."""
    structname2wrappername = OrderedDict()
    structname2wrapfuncname = OrderedDict()
    for name in struct_dict:
        assert name.startswith('mj')
        structname2wrappername[name] = 'PyMj' + name[2:]
        structname2wrapfuncname[name] = 'WrapMj' + name[2:]
    return structname2wrappername, structname2wrapfuncname


def _add_named_access_methods(obj_type, attr_name, attr_name_short):
    """Generate ``get_<obj_type>[_<attr_name_short>](name)`` on PyMjData.

    The getter looks the object id up by name and indexes ``self._<attr_name>``.
    If the getter name differs from *attr_name*, ``get_<attr_name>`` is also
    generated and raises, pointing to the new getter.
    """
    getter_name = obj_type
    if attr_name_short is not None:
        getter_name += "_" + attr_name_short
    # Fields ending in "mat" are returned reshaped to 3x3 (assumed 9 values per object).
    reshape_suffix = ".reshape((3, 3))" if attr_name.endswith('mat') else ''
    code = """
    def get_{getter_name}(self, name):
        id = self._model.{obj_type}_name2id(name)
        return self._{attr_name}[id]{reshape_suffix}\n""".format(
        obj_type=obj_type, getter_name=getter_name, attr_name=attr_name,
        reshape_suffix=reshape_suffix)
    if getter_name != attr_name:
        code += """
    def get_{attr_name}(self, name):
        raise RuntimeError("get_{getter_name} should be used instead of get_{attr_name}")\n""".format(
            getter_name=getter_name, attr_name=attr_name)
    return code


def _add_named_jacobian_methods(obj_type):
    """Generate by-name Jacobian (``mj_jac<Type>``) and velocity getters on PyMjData."""
    cap_type = obj_type.title()  # Capitalized
    code = """
    def get_{obj_type}_jacp(self, name, np.ndarray[double, ndim=1, mode="c"] jacp = None):
        id = self._model.{obj_type}_name2id(name)
        if jacp is None:
            jacp = np.zeros(3 * self._model.nv)
        cdef double * jacp_view = &jacp[0]
        mj_jac{cap_type}(self._model.ptr, self.ptr, jacp_view, NULL, id)
        return jacp

    def get_{obj_type}_jacr(self, name, np.ndarray[double, ndim=1, mode="c"] jacr = None):
        id = self._model.{obj_type}_name2id(name)
        if jacr is None:
            jacr = np.zeros(3 * self._model.nv)
        cdef double * jacr_view = &jacr[0]
        mj_jac{cap_type}(self._model.ptr, self.ptr, NULL, jacr_view, id)
        return jacr

    def get_{obj_type}_xvelp(self, name):
        id = self._model.{obj_type}_name2id(name)
        jacp = self.get_{obj_type}_jacp(name).reshape((3, self._model.nv))
        xvelp = np.dot(jacp, self.qvel)
        return xvelp

    def get_{obj_type}_xvelr(self, name):
        id = self._model.{obj_type}_name2id(name)
        jacr = self.get_{obj_type}_jacr(name).reshape((3, self._model.nv))
        xvelr = np.dot(jacr, self.qvel)
        return xvelr\n""".format(obj_type=obj_type, cap_type=cap_type)
    return code


def _add_jacobian_getters(obj_type):
    """Generate properties returning Jacobians/velocities for all objects of a type."""
    cap_type = obj_type.title()  # Capitalized
    code = '''
    @property
    def {obj_type}_jacp(self):
        jacps = np.zeros((self._model.n{obj_type}, 3 * self._model.nv))
        cdef double [:] jacp_view
        for i, jacp in enumerate(jacps):
            jacp_view = jacp
            mj_jac{cap_type}(self._model.ptr, self.ptr, &jacp_view[0], NULL, i)
        return jacps

    @property
    def {obj_type}_jacr(self):
        jacrs = np.zeros((self._model.n{obj_type}, 3 * self._model.nv))
        cdef double [:] jacr_view
        for i, jacr in enumerate(jacrs):
            jacr_view = jacr
            mj_jac{cap_type}(self._model.ptr, self.ptr, NULL, &jacr_view[0], i)
        return jacrs

    @property
    def {obj_type}_xvelp(self):
        jacp = self.{obj_type}_jacp.reshape((self._model.n{obj_type}, 3, self._model.nv))
        xvelp = np.dot(jacp, self.qvel)
        return xvelp

    @property
    def {obj_type}_xvelr(self):
        jacr = self.{obj_type}_jacr.reshape((self._model.n{obj_type}, 3, self._model.nv))
        xvelr = np.dot(jacr, self.qvel)
        return xvelr\n'''.format(obj_type=obj_type, cap_type=cap_type)
    return code


def _set_body_identifiers(short_name, addr_name, long_name, obj_name):
    """Generate the ``_set`` line filling the name tables of one object type on PyMjModel."""
    return ("        self.{long_name}_names, self._{long_name}_name2id, self._{long_name}_id2name = "
            "self._extract_mj_names(p, p.name_{addr_name}adr, p.n{short_name}, mjtObj.mjOBJ_{obj_name})\n"
            ).format(short_name=short_name, long_name=long_name, obj_name=obj_name, addr_name=addr_name)


def _add_getters(obj_type):
    """Generate ``<obj_type>_id2name`` / ``<obj_type>_name2id`` lookups raising ValueError on misses."""
    return '''
    def {obj_type}_id2name(self, id):
        if id not in self._{obj_type}_id2name:
            raise ValueError("No {obj_type} with id %d exists." % id)
        return self._{obj_type}_id2name[id]

    def {obj_type}_name2id(self, name):
        if name not in self._{obj_type}_name2id:
            raise ValueError("No \\"{obj_type}\\" with name %s exists. Available \\"{obj_type}\\" names = %s." % (name, self.{obj_type}_names))
        return self._{obj_type}_name2id[name]
'''.format(obj_type=obj_type)


def get_const_from_define(full_src_lines):
    """Return ``const.py`` lines for numeric ``#define NAME VALUE // comment`` lines.

    The first two characters of NAME are dropped, values are parsed as float
    (a trailing C ``f`` suffix is removed), and duplicates are skipped.
    Defines without a value are ignored; values that don't parse are reported
    and skipped.
    """
    define = "#define"
    define_code = []
    seen = set()
    for line in full_src_lines:
        if not line.startswith(define):
            continue
        # Collapse every run of whitespace (spaces, tabs) into one space.
        line = " ".join(line[len(define):].split())
        comment = ""
        if "//" in line:
            line, comment = line.split("//", 1)
            line, comment = line.strip(), comment.strip()
        if " " not in line:
            continue  # e.g. include guards: no value
        var, val = line.split(" ", 1)
        try:
            # In C/C++ numbers can have an 'f' suffix, specifying a single-precision number.
            # That is not supported by the Python floating point parser, so strip it.
            if val.endswith('f'):
                val = val[:-1]
            val = float(val)
        except ValueError:
            traceback.print_exc()
            print("Couldn't parse line: %s" % line)
            continue
        varname = var[2:]
        if varname in seen:
            print("Already seen {name}, skipping".format(name=varname))
            continue
        seen.add(varname)
        new_line = varname + " = " + str(val)
        new_line += " " * (35 - len(new_line))
        new_line += " # " + comment
        define_code.append(new_line)
    return define_code


def _convert_func_arg(arg):
    """Translate one C parameter ``"<type> <name>"`` of a mujoco.h prototype.

    :return: ``(cython_parameter, c_call_argument)``, or ``None`` when the
        parameter type is not supported (the whole function is then skipped).
    """
    tokens = arg.strip().split(" ")
    data_type = " ".join(tokens[:-1])
    var_name = tokens[-1]
    if "[" in var_name:
        # Fixed-size array parameters are passed as pointers.
        data_type = data_type + "*"
        var_name = var_name[:var_name.find("[")]
    # Some words are keywords in Python but not in C/C++, so the headers may
    # use them as parameter names; rename them consistently.
    if keyword.iskeyword(var_name):
        var_name = '_' + var_name

    if data_type == "char*":
        return "str " + var_name, var_name + ".encode()"
    if data_type == "unsigned char":
        return None
    if data_type == "mjtNum":
        return "float " + var_name, var_name
    if data_type == "mjtNum*":
        return 'np.ndarray[np.float64_t, mode="c", ndim=1] ' + var_name, "&%s[0]" % var_name
    if data_type == "mjtByte":
        return "int " + var_name, var_name
    if data_type == "mjtByte*":
        return 'np.ndarray[np.uint8_t, mode="c", ndim=1] ' + var_name, "&%s[0]" % var_name
    if data_type[:2] == "mj" and data_type[-1:] == "*":
        return "PyMj" + data_type[2:-1] + " " + var_name, var_name + ".ptr"
    if data_type[:2] == "mj" and "*" not in data_type:
        return "PyMj" + data_type[2:] + " " + var_name, var_name + ".ptr[0]"  # dereference
    # Exact comparisons: the previous substring test (`in "int"`) also matched "" etc.
    if data_type == "int":
        return "int " + var_name, var_name
    if data_type == "int*":
        return "uintptr_t " + var_name, "<int*>" + var_name
    # XXX: unsupported parameter type
    return None


def _wrap_return(ret_name, call):
    """Return the wrapper body statement for a C function returning *ret_name*.

    :return: the statement, or ``None`` if the return type is not supported.
    """
    if ret_name in ("int", "mjtNum"):
        return "return " + call
    if ret_name == "void":
        return call
    if (ret_name[:2] == "mj" and ret_name[-1:] == "*"
            and ret_name not in ("mjtNum*", "mjData*")):
        return "return WrapMj" + ret_name[2:-1] + "(" + call + ")"
    return None


def get_funcs(fname):
    """Generate ``def _<func>(...)`` Cython wrappers for functions in *fname*.

    The header is preprocessed and everything before ``int mj_activat`` is
    dropped; each remaining ``;``-separated prototype is wrapped when all its
    parameter types and its return type are supported.

    :raises RuntimeError: if the ``mj_activate`` prototype is not found.
    """
    src = subprocess.check_output([c_compiler, '-E', '-P', fname]).decode()
    start = src.find("int mj_activat")
    if start < 0:
        raise RuntimeError("Could not find 'int mj_activat' in preprocessed %s" % fname)
    # Normalise all whitespace (incl. newlines and \r) to single spaces.
    src = " ".join(src[start:].split())
    src = src.replace("const ", "").replace(", ", ",").strip()
    funcs = [f.strip() for f in src.split(";") if f.strip()]

    chunks = []
    for f in funcs:
        ret_name = f.split(" ")[0]
        func_name = f.split(" ")[1].split("(")[0]
        args = f.split("(")[1][:-1]
        converted = [] if args == "void" else [_convert_func_arg(a) for a in args.split(",")]
        if any(c is None for c in converted):
            continue
        ret_val = "%s(%s)" % (func_name, ", ".join(c_arg for _, c_arg in converted))
        statement = _wrap_return(ret_name, ret_val)
        if statement is None:
            continue
        chunks.append("def _%s(%s):\n    %s" % (
            func_name, ", ".join(py_arg for py_arg, _ in converted), statement))

    ret = "".join(chunk + "\n\n" for chunk in chunks)
    print(ret)
    print("Generated %d out of %d" % (len(chunks), len(funcs)))
    return ret


_MJMODEL_METHODS = '''
    cdef inline tuple _extract_mj_names(self, mjModel* p, int*name_adr, int n, mjtObj obj_type):
        cdef char *name
        cdef int obj_id

        # objects don't need to be named in the XML, so name might be None
        id2name = {i: None for i in range(n)}
        name2id = {}
        for i in range(n):
            name = p.names + name_adr[i]
            decoded_name = name.decode()
            if decoded_name:
                obj_id = mj_name2id(p, obj_type, name)
                assert 0 <= obj_id < n and id2name[obj_id] is None
                name2id[decoded_name] = obj_id
                id2name[obj_id] = decoded_name

        # sort names by increasing id to keep order deterministic
        return tuple(id2name[id] for id in sorted(name2id.values())), name2id, id2name

    def get_xml(self):
        cdef char errstr[300]
        cdef int ret
        with TemporaryDirectory() as td:
            filename = os.path.join(td, 'model.xml')
            with wrap_mujoco_warning():
                ret = mj_saveLastXML(filename.encode(), self.ptr, errstr, 300)
            if ret == 0:
                raise Exception('Failed to save XML: {}'.format(errstr))
            return open(filename).read()

    def get_mjb(self):
        with TemporaryDirectory() as td:
            filename = os.path.join(td, 'model.mjb')
            with wrap_mujoco_warning():
                mj_saveModel(self.ptr, filename.encode(), NULL, 0)
            return open(filename, 'rb').read()

    def set_userdata_names(self, userdata_names):
        assert isinstance(userdata_names, (list, tuple)), 'bad userdata names'
        assert len(userdata_names) <= self.nuserdata, 'insufficient userdata'
        self.userdata_names = tuple(userdata_names)
        self._userdata_id2name = dict()
        self._userdata_name2id = dict()
        for i, name in enumerate(userdata_names):
            assert isinstance(name, str), 'names must all be strings'
            self._userdata_id2name[i] = name
            self._userdata_name2id[name] = i

    def __dealloc__(self):
        mj_deleteModel(self.ptr)
'''


def _mjmodel_extras():
    """Return ``(extra, extra_set)`` for PyMjModel.

    ``extra`` goes into the class body (name tables, id/name lookups, XML/MJB
    export, userdata names, joint address helpers); ``extra_set`` goes into
    ``_set`` before ``self.ptr = p`` and fills the name tables.
    """
    obj_types = ['body', 'joint', 'geom', 'site', 'light', 'camera',
                 'actuator', 'sensor', 'tendon', 'mesh']
    extra = '\n'
    extra += '    cdef readonly tuple ' + ', '.join(o + '_names' for o in obj_types) + '\n'
    extra += '    cdef readonly dict ' + ', '.join('_' + o + '_id2name' for o in obj_types) + '\n'
    extra += '    cdef readonly dict ' + ', '.join('_' + o + '_name2id' for o in obj_types) + '\n'
    for obj_type in obj_types:
        extra += _add_getters(obj_type)

    # Note: named userdata fields are not present in MuJoCo,
    # they're special accessors we add in mujoco-py.
    # So these fields need to be python accessible instead of readonly.
    extra += '    cdef public tuple userdata_names\n'
    extra += '    cdef public dict _userdata_id2name\n'
    extra += '    cdef public dict _userdata_name2id\n'
    extra += _add_getters('userdata')
    extra += _MJMODEL_METHODS

    extra_set = '\n'
    for short_name, addr_name, long_name, obj_name in _NAMED_OBJECT_FIELDS:
        extra_set += _set_body_identifiers(short_name, addr_name, long_name, obj_name)
    # userdata_names is empty at construction time
    extra_set += '        self.userdata_names = tuple()\n'
    extra_set += '        self._userdata_name2id = dict()\n'
    extra_set += '        self._userdata_id2name = dict()\n'

    for q_type in ('pos', 'vel'):
        # Position dimensionality and degrees of freedom are different
        # for free and ball joints.
        if q_type == 'pos':
            adr_name, free_ndim, ball_ndim = 'qpos', 7, 4
        else:
            adr_name, free_ndim, ball_ndim = 'dof', 6, 3
        extra += """
    def get_joint_q{q_type}_addr(self, name):
        '''
        Returns the q{q_type} address for given joint.

        Returns:
        - address (int, tuple): returns int address if 1-dim joint, otherwise
            returns the a (start, end) tuple for {q_type}[start:end] access.
        '''
        joint_id = self.joint_name2id(name)
        joint_type = self.jnt_type[joint_id]
        joint_addr = self.jnt_{adr_name}adr[joint_id]
        if joint_type == mjtJoint.mjJNT_FREE:
            ndim = {free_ndim}
        elif joint_type == mjtJoint.mjJNT_BALL:
            ndim = {ball_ndim}
        else:
            assert joint_type in (mjtJoint.mjJNT_HINGE, mjtJoint.mjJNT_SLIDE)
            ndim = 1

        if ndim == 1:
            return joint_addr
        else:
            return (joint_addr, joint_addr + ndim)\n""".format(
            q_type=q_type, adr_name=adr_name, free_ndim=free_ndim, ball_ndim=ball_ndim)
    return extra, extra_set


def _mjdata_extras():
    """Return ``(extra, extra_getters)`` for PyMjData.

    ``extra`` goes into the class body (unambiguous body accessors, by-name
    getters, Jacobian helpers, mocap and joint accessors, ``__dealloc__``);
    ``extra_getters`` (Jacobian properties) is appended after the member
    properties.
    """
    extra = '''
    @property
    def body_xpos(self):
        return self._xpos

    @property
    def body_xquat(self):
        return self._xquat

    @property
    def body_xmat(self):
        return self._xmat

    @property
    def active_contacts_efc_pos(self):
        return self._efc_pos[self.ne:self.nefc]

    def __dealloc__(self):
        mj_deleteData(self.ptr)
'''
    extra_getters = ''

    extra += _add_named_access_methods('body', 'xpos', 'xpos')
    extra += _add_named_access_methods('body', 'xquat', 'xquat')
    extra += _add_named_access_methods('body', 'xmat', 'xmat')
    extra += _add_named_access_methods('body', 'xipos', 'xipos')
    extra += _add_named_access_methods('body', 'ximat', 'ximat')
    extra += _add_named_jacobian_methods('body')
    extra_getters += _add_jacobian_getters('body')

    extra += _add_named_access_methods('joint', 'xanchor', 'xanchor')
    extra += _add_named_access_methods('joint', 'xaxis', 'xaxis')

    extra += _add_named_access_methods('geom', 'geom_xpos', 'xpos')
    extra += _add_named_access_methods('geom', 'geom_xmat', 'xmat')
    extra += _add_named_jacobian_methods('geom')
    extra_getters += _add_jacobian_getters('geom')

    extra += _add_named_access_methods('site', 'site_xpos', 'xpos')
    extra += _add_named_access_methods('site', 'site_xmat', 'xmat')
    extra += _add_named_jacobian_methods('site')
    extra_getters += _add_jacobian_getters('site')

    extra += _add_named_access_methods('camera', 'cam_xpos', 'xpos')
    extra += _add_named_access_methods('camera', 'cam_xmat', 'xmat')

    extra += _add_named_access_methods('light', 'light_xpos', 'xpos')
    extra += _add_named_access_methods('light', 'light_xdir', 'xdir')

    extra += _add_named_access_methods('sensor', 'sensordata', None)
    extra += _add_named_access_methods('userdata', 'userdata', None)

    for pose_type in ('pos', 'quat'):
        extra += """
    def get_mocap_{pose_type}(self, name):
        body_id = self._model.body_name2id(name)
        mocap_id = self._model.body_mocapid[body_id]
        return self.mocap_{pose_type}[mocap_id]

    def set_mocap_{pose_type}(self, name, value):
        body_id = self._model.body_name2id(name)
        mocap_id = self._model.body_mocapid[body_id]
        self.mocap_{pose_type}[mocap_id] = value\n""".format(pose_type=pose_type)

    for q_type in ('pos', 'vel'):
        extra += """
    def get_joint_q{q_type}(self, name):
        addr = self._model.get_joint_q{q_type}_addr(name)
        if isinstance(addr, (int, np.int32, np.int64)):
            return self.q{q_type}[addr]
        else:
            start_i, end_i = addr
            return self.q{q_type}[start_i:end_i]

    def set_joint_q{q_type}(self, name, value):
        addr = self._model.get_joint_q{q_type}_addr(name)
        if isinstance(addr, (int, np.int32, np.int64)):
            self.q{q_type}[addr] = value
        else:
            start_i, end_i = addr
            value = np.array(value)
            assert value.shape == (end_i - start_i,), (
                "Value has incorrect shape %s: %s" % (name, value))
            self.q{q_type}[start_i:end_i] = value\n""".format(q_type=q_type)
    return extra, extra_getters


def _class_extras(name):
    """Return ``(extra, extra_set, extra_getters)`` for struct *name*, or
    ``None`` if no wrapper class should be emitted for it."""
    if name == "mjModel":
        extra, extra_set = _mjmodel_extras()
        return extra, extra_set, ''
    if name == "mjData":
        extra, extra_getters = _mjdata_extras()
        return extra, '', extra_getters
    if name in _SELF_ALLOCATED_STRUCTS:
        # NOTE(review): _set() overwrites self.ptr, so a wrapper built through
        # WrapMj<name>(p) leaks its own buffer and PyMem_Free()s the foreign
        # pointer p in __dealloc__ — confirm these are never wrapped that way.
        extra = '''
    def __cinit__(self):
        self.ptr = <{name}*> PyMem_Malloc(sizeof({name}))
        if not self.ptr:
            raise MemoryError()

    def __dealloc__(self):
        PyMem_Free(self.ptr)
'''.format(name=name)
        return extra, '', ''
    if name in _STRUCTS_WITHOUT_TYPEDEF:
        return None
    if name[:2] == 'mj':
        extra = '''
    def __cinit__(self):
        self.ptr = NULL
'''
        return extra, '', ''
    return '', '', ''


def _generate_members(name, fields, struct_dict, structname2wrappername,
                      structname2wrapfuncname, needed_1d_wrappers, needed_2d_wrappers):
    """Generate the attribute declarations, ``_set`` initializers and
    properties of the wrapper class for struct *name*.

    Element types that need ``_wrap_<type>_1d`` / ``_wrap_<type>_2d`` helpers
    are added to *needed_1d_wrappers* / *needed_2d_wrappers* in place.

    :return: ``(member_decls, member_initializers, member_getters)``, lists of
        already-indented Cython source snippets.
    """
    member_decls, member_initializers, member_getters = [], [], []
    # Inside mjModel._set sizes come from the raw pointer ``p``; other structs
    # read them from their ``model`` (PyMjModel) parameter.
    model_var_name = 'p' if name == 'mjModel' else 'model'

    def model_arg_for(struct_type):
        # Extra argument for WrapMj<struct_type>() when that wrapper needs the model.
        return ', ' + model_var_name if struct_dict[struct_type]['depends_on_model'] else ''

    def plain_getter(attr):
        return '    @property\n    def {name}(self): return self._{name}'.format(name=attr)

    for scalar_name, scalar_type in fields['scalars']:
        if scalar_type in _SCALAR_TYPES:
            member_getters.append(
                '    @property\n    def {name}(self): return self.ptr.{name}'.format(name=scalar_name))
            member_getters.append(
                '    @{name}.setter\n    def {name}(self, {type} x): self.ptr.{name} = x'.format(
                    name=scalar_name, type=scalar_type))
        elif scalar_type in struct_dict:
            # This is a struct member
            member_decls.append('    cdef {} _{}'.format(
                structname2wrappername[scalar_type], scalar_name))
            member_initializers.append(
                '        self._{scalar_name} = {wrap_func_name}(&p.{scalar_name}{model_arg})'.format(
                    scalar_name=scalar_name,
                    wrap_func_name=structname2wrapfuncname[scalar_type],
                    model_arg=model_arg_for(scalar_type)))
            member_getters.append(plain_getter(scalar_name))
        else:
            print('Warning: skipping {} {}.{}'.format(scalar_type, name, scalar_name))

    # Pointer types
    for ptr_name, ptr_type, (shape0, shape1) in fields['ptrs']:
        if ptr_type in struct_dict:
            assert shape0.startswith('n') and shape1 == 1
            member_decls.append('    cdef tuple _{}'.format(ptr_name))
            member_initializers.append(
                '        self._{ptr_name} = tuple([{wrap_func_name}(&p.{ptr_name}[i]{model_arg}) for i in range({size0})])'.format(
                    ptr_name=ptr_name,
                    wrap_func_name=structname2wrapfuncname[ptr_type],
                    size0='{}.{}'.format(model_var_name, shape0),
                    model_arg=model_arg_for(ptr_type)))
        else:
            assert name == 'mjModel' or fields['depends_on_model']
            member_decls.append('    cdef np.ndarray _{}'.format(ptr_name))
            if shape0 == 1 or shape1 == 1:
                # Collapse to 1d for the user's convenience
                size0 = shape1 if shape0 == 1 else shape0
                member_initializers.append(
                    '        self._{ptr_name} = _wrap_{ptr_type}_1d(p.{ptr_name}, {size0})'.format(
                        ptr_name=ptr_name,
                        ptr_type=ptr_type.replace(' ', '_'),
                        size0=format_size_argument(model_var_name, size0)))
                # Register the helper; previously this relied on some fixed-size
                # array happening to have the same element type.
                needed_1d_wrappers.add(ptr_type)
            else:
                member_initializers.append(
                    '        self._{ptr_name} = _wrap_{ptr_type}_2d(p.{ptr_name}, {size0}, {size1})'.format(
                        ptr_name=ptr_name,
                        ptr_type=ptr_type.replace(' ', '_'),
                        size0=format_size_argument(model_var_name, shape0),
                        size1=format_size_argument(model_var_name, shape1)))
                needed_2d_wrappers.add(ptr_type)
        if ptr_name in REPLACEMENT_BY_ORIGINAL:
            member_getters.append("""
    @property
    def {name}(self):
        raise RuntimeError("{replacement} should be used instead of {name}")\n""".format(
                name=ptr_name, replacement=REPLACEMENT_BY_ORIGINAL[ptr_name]))
        else:
            member_getters.append(plain_getter(ptr_name))

    # Array types: handle the same way as pointers
    for array_name, array_type, array_size in fields['arrays']:
        if array_type in struct_dict:
            # Array of structs. The model argument belongs inside the call,
            # after the element index: Wrap(&p.arr[i], model).
            member_decls.append('    cdef list _{}'.format(array_name))
            member_initializers.append(
                '        self._{array_name} = [{wrap_func_name}(&p.{array_name}[i]{model_arg}) for i in range({array_size})]'.format(
                    array_name=array_name,
                    array_size=array_size,
                    wrap_func_name=structname2wrapfuncname[array_type],
                    model_arg=model_arg_for(array_type)))
        else:
            member_decls.append('    cdef np.ndarray _{}'.format(array_name))
            member_initializers.append(
                '        self._{array_name} = _wrap_{array_type}_1d(&p.{array_name}[0], {size})'.format(
                    array_name=array_name,
                    array_type=array_type.replace(' ', '_'),
                    size=array_size))
            needed_1d_wrappers.add(array_type)
        member_getters.append(plain_getter(array_name))

    # 2D-Array types: handle the same way as pointers
    for array_name, array_type, array_size in fields['arrays2d']:
        if array_type in struct_dict:
            print("Skipping 2d array of structs {name}.{arr_name}: <{arr_type}[:{arr_size0},:{arr_size1}]>".format(
                name=name, arr_name=array_name, arr_type=array_type,
                arr_size0=array_size[0], arr_size1=array_size[1]))
            continue
        member_decls.append('    cdef np.ndarray _{}'.format(array_name))
        member_initializers.append(
            '        self._{array_name} = _wrap_{array_type}_2d(&p.{array_name}[0][0], {size0}, {size1})'.format(
                array_name=array_name,
                array_type=array_type.replace(' ', '_'),
                size0=array_size[0],
                size1=array_size[1]))
        member_getters.append(plain_getter(array_name))
        needed_2d_wrappers.add(array_type)

    return member_decls, member_initializers, member_getters


_WRAPPER_CLASS_TEMPLATE = '''
cdef class {wrapper_name}(object):
    cdef {struct_name}* ptr
    {model_decl}
    {member_decls}
    {extra}
    @property
    def uintptr(self): return <uintptr_t>self.ptr

    cdef void _set(self, {struct_name}* p{model_param}):
        {extra_set}
        self.ptr = p
        {model_setter}
        {member_initializers}

\n{member_getters}

cdef {wrapper_name} {wrap_func_name}({struct_name}* p{model_param}):
    cdef {wrapper_name} o = {wrapper_name}()
    o._set(p{model_arg})
    return o
'''


def _generate_array_wrappers(needed_1d_wrappers, needed_2d_wrappers):
    """Generate the ``_wrap_<type>_1d`` / ``_wrap_<type>_2d`` array-to-NumPy view helpers."""
    # TODO: instead of returning None for empty arrays, return NumPy arrays with the
    # appropriate shape and type. The only reason we're not doing this already is that
    # cython's views don't work with 0-length axes, even though NumPy does.
    # TODO: set NumPy array type explicitly (e.g. char will be viewed incorrectly as np.int64)
    code = []
    for type_name in sorted(needed_1d_wrappers):
        code.append('''
cdef inline np.ndarray _wrap_{type_name_nospaces}_1d({type_name}* a, int shape0):
    if shape0 == 0: return None
    cdef {type_name}[:] b = <{type_name}[:shape0]> a
    return np.asarray(b)
'''.format(type_name_nospaces=type_name.replace(' ', '_'), type_name=type_name).strip())

    for type_name in sorted(needed_2d_wrappers):
        code.append('''
cdef inline np.ndarray _wrap_{type_name_nospaces}_2d({type_name}* a, int shape0, int shape1):
    if shape0 * shape1 == 0: return None
    cdef {type_name}[:,:] b = <{type_name}[:shape0,:shape1]> a
    return np.asarray(b)
'''.format(type_name_nospaces=type_name.replace(' ', '_'), type_name=type_name).strip())
    return code


def _generate_const_code(full_src_lines, processed_src):
    """Return the full text of the generated ``const.py``."""
    define_const = get_const_from_define(full_src_lines)
    enum_const = get_const_from_enum(processed_src)
    const_code = "# Automatically generated. Do not modify!\n\n###### const from defines ######\n"
    const_code += "\n".join(define_const)
    const_code += "\n\n###### const from enums ######\n\n"
    const_code += "\n".join(enum_const)
    return const_code


def main():
    """Parse the headers in ~/.mujoco/mujoco210/include and write the wrapper and constant files."""
    HEADER_DIR = os.path.expanduser(os.path.join('~', '.mujoco', 'mujoco210', 'include'))
    HEADER_FILES = [
        'mjmodel.h',
        'mjdata.h',
        'mjvisualize.h',
        'mjrender.h',
        'mjui.h',
    ]
    if len(sys.argv) > 1:
        OUTPUT = sys.argv[1]
    else:
        OUTPUT = os.path.join('mujoco_py', 'generated', 'wrappers.pxi')
    OUTPUT_CONST = os.path.join('mujoco_py', 'generated', 'const.py')

    funcs = get_funcs(os.path.join(HEADER_DIR, "mujoco.h"))
    full_src_lines = get_full_scr_lines(HEADER_DIR, HEADER_FILES)
    array_shapes = get_array_shapes(full_src_lines)
    processed_src = get_processed_src(HEADER_DIR, full_src_lines)
    struct_dict = get_full_struct_dict(processed_src, array_shapes)
    structname2wrappername, structname2wrapfuncname = get_struct_wrapper(struct_dict)

    const_code = _generate_const_code(full_src_lines, processed_src)
    with open(OUTPUT_CONST, 'w', encoding='utf-8') as f:
        f.write(const_code)

    code = []
    needed_1d_wrappers = set()
    needed_2d_wrappers = set()

    # ===== Generate wrapper extension classes =====
    for name, fields in struct_dict.items():
        # Members are generated (and wrapper types registered) even for structs
        # that end up skipped below, matching the historical output.
        member_decls, member_initializers, member_getters = _generate_members(
            name, fields, struct_dict, structname2wrappername, structname2wrapfuncname,
            needed_1d_wrappers, needed_2d_wrappers)

        extras = _class_extras(name)
        if extras is None:
            continue
        extra, extra_set, extra_getters = extras

        member_getters = '\n'.join(member_getters) + extra_getters
        member_decls = '\n' + '\n'.join(member_decls) if member_decls else ''
        member_initializers = '\n' + '\n'.join(member_initializers) if member_initializers else ''
        depends_on_model = fields['depends_on_model']
        model_decl = '\n    cdef PyMjModel _model' if depends_on_model else ''
        model_param = ', PyMjModel model' if depends_on_model else ''
        model_setter = 'self._model = model' if depends_on_model else ''
        model_arg = ', model' if depends_on_model else ''

        code.append(_WRAPPER_CLASS_TEMPLATE.format(
            wrapper_name=structname2wrappername[name],
            extra=extra,
            extra_set=extra_set,
            struct_name=name,
            wrap_func_name=structname2wrapfuncname[name],
            model_decl=model_decl,
            model_param=model_param,
            model_setter=model_setter,
            model_arg=model_arg,
            member_decls=member_decls,
            member_initializers=member_initializers,
            member_getters=member_getters,
        ).strip())

    # ===== Generate array-to-NumPy wrappers =====
    code.extend(_generate_array_wrappers(needed_1d_wrappers, needed_2d_wrappers))

    header = '''# cython: language_level=3
# Automatically generated. Do not modify!

include "../pxd/mujoco.pxd"
from libc.stdint cimport uintptr_t
from cpython.mem cimport PyMem_Malloc, PyMem_Free
cimport numpy as np
import numpy as np
from tempfile import TemporaryDirectory

'''
    code.append(funcs)
    code = header + '\n\n'.join(code) + '\n'
    print(len(code.splitlines()))
    with open(OUTPUT, 'w', encoding='utf-8') as f:
        f.write(code)


if __name__ == "__main__":
    main()