in scripts/gen_wrappers.py [0:0]
def main():
HEADER_DIR = os.path.expanduser(os.path.join('~', '.mujoco', 'mujoco210', 'include'))
HEADER_FILES = [
'mjmodel.h',
'mjdata.h',
'mjvisualize.h',
'mjrender.h',
'mjui.h'
]
if len(sys.argv) > 1:
OUTPUT = sys.argv[1]
else:
OUTPUT = os.path.join('mujoco_py', 'generated', 'wrappers.pxi')
OUTPUT_CONST = os.path.join('mujoco_py', 'generated', 'const.py')
funcs = get_funcs(os.path.join(HEADER_DIR, "mujoco.h"))
full_src_lines = get_full_scr_lines(HEADER_DIR, HEADER_FILES)
array_shapes = get_array_shapes(full_src_lines)
processed_src = get_processed_src(HEADER_DIR, full_src_lines)
struct_dict = get_full_struct_dict(processed_src, array_shapes)
structname2wrappername, structname2wrapfuncname = get_struct_wrapper(struct_dict)
define_const = get_const_from_define(full_src_lines)
enum_const = get_const_from_enum(processed_src)
const_code = "# Automatically generated. Do not modify!\n\n###### const from defines ######\n"
const_code += "\n".join(define_const)
const_code += "\n\n###### const from enums ######\n\n"
const_code += "\n".join(enum_const)
with open(OUTPUT_CONST, 'w') as f:
f.write(const_code)
code = []
needed_1d_wrappers = set()
needed_2d_wrappers = set()
# ===== Generate wrapper extension classes =====
for name, fields in struct_dict.items():
member_decls, member_initializers, member_getters = [], [], []
model_var_name = 'p' if name == 'mjModel' else 'model'
# Disabling a few accessors that are unsafe due to ambiguous meaning.
REPLACEMENT_BY_ORIGINAL = OrderedDict([
('xpos', 'body_xpos'),
('xmat', 'body_xmat'),
('xquat', 'body_xquat'),
('efc_pos', 'active_contacts_efc_pos'),
])
for scalar_name, scalar_type in fields['scalars']:
if scalar_type in ['float', 'int', 'mjtNum', 'mjtByte', 'unsigned int']:
member_getters.append(
' @property\n def {name}(self): return self.ptr.{name}'.format(name=scalar_name))
member_getters.append(' @{name}.setter\n def {name}(self, {type} x): self.ptr.{name} = x'.format(
name=scalar_name, type=scalar_type))
elif scalar_type in struct_dict:
# This is a struct member
member_decls.append(' cdef {} _{}'.format(
structname2wrappername[scalar_type], scalar_name))
member_initializers.append(' self._{scalar_name} = {wrap_func_name}(&p.{scalar_name}{model_arg})'.format(
scalar_name=scalar_name,
wrap_func_name=structname2wrapfuncname[scalar_type],
model_arg=(
', ' + model_var_name) if struct_dict[scalar_type]['depends_on_model'] else ''
))
member_getters.append(
' @property\n def {name}(self): return self._{name}'.format(name=scalar_name))
else:
print('Warning: skipping {} {}.{}'.format(
scalar_type, name, scalar_name))
# Pointer types
for ptr_name, ptr_type, (shape0, shape1) in fields['ptrs']:
if ptr_type in struct_dict:
assert shape0.startswith('n') and shape1 == 1
member_decls.append(' cdef tuple _{}'.format(ptr_name))
member_initializers.append(
' self._{ptr_name} = tuple([{wrap_func_name}(&p.{ptr_name}[i]{model_arg}) for i in range({size0})])'.format(
ptr_name=ptr_name,
wrap_func_name=structname2wrapfuncname[ptr_type],
size0='{}.{}'.format(model_var_name, shape0),
model_arg=(
', ' + model_var_name) if struct_dict[ptr_type]['depends_on_model'] else ''
))
else:
assert name == 'mjModel' or fields['depends_on_model']
member_decls.append(' cdef np.ndarray _{}'.format(ptr_name))
if shape0 == 1 or shape1 == 1:
# Collapse to 1d for the user's convenience
size0 = shape1 if shape0 == 1 else shape0
member_initializers.append(
' self._{ptr_name} = _wrap_{ptr_type}_1d(p.{ptr_name}, {size0})'.format(
ptr_name=ptr_name,
ptr_type=ptr_type.replace(' ', '_'),
size0=format_size_argument(model_var_name, size0),
))
else:
member_initializers.append(
' self._{ptr_name} = _wrap_{ptr_type}_2d(p.{ptr_name}, {size0}, {size1})'.format(
ptr_name=ptr_name,
ptr_type=ptr_type.replace(' ', '_'),
size0=format_size_argument(model_var_name, shape0),
size1=format_size_argument(model_var_name, shape1),
))
needed_2d_wrappers.add(ptr_type)
if ptr_name in REPLACEMENT_BY_ORIGINAL:
member_getters.append("""