mujoco_worldgen/util/envs/flexible_load.py (79 lines of code) (raw):

import os
import numpy as np
import json
import _jsonnet
from os.path import join
from collections import OrderedDict
from glob import glob
from mujoco_py import load_model_from_xml, load_model_from_mjb, MjSim
from runpy import run_path
from mujoco_worldgen import Env
from mujoco_worldgen.util.path import worldgen_path
from mujoco_worldgen.util.types import extract_matching_arguments
from mujoco_worldgen.parser import parse_file, unparse_dict


def get_function(fn_data):
    """
    Resolves a function described by a dict and optionally binds default
    keyword arguments to it.

    :param fn_data: dict with the keys:
        'function' - string of the form "<module path>:<function name>";
        'args' - dict of keyword arguments to bind (may be missing, None or
        empty, in which case the function is returned unwrapped).
    :return: the imported function itself, or a wrapper which calls it with
        the bound keyword arguments. Keyword arguments given at call time
        override the bound ones.
    """
    name = fn_data['function']
    # A missing or null 'args' entry means "nothing to bind".
    extra_args = fn_data.get('args') or {}
    module_path, function_name = name.rsplit(':', 1)
    # A non-empty fromlist makes __import__ return the leaf module rather
    # than the top-level package.
    module = __import__(module_path, fromlist=(function_name,))
    result = getattr(module, function_name)
    if not extra_args:
        return result

    def result_wrapper(*args, **kwargs):
        # Copy so that the bound defaults are never mutated between calls.
        actual_kwargs = extra_args.copy()
        actual_kwargs.update(kwargs)
        return result(*args, **actual_kwargs)

    return result_wrapper


def load_env(pattern, core_dir=worldgen_path(), envs_dir='examples', xmls_dir='xmls',
             return_args_remaining=False, **kwargs):
    """
    Flexible load of an environment based on `pattern`.
    Passes args to make_env().

    The pattern is tried in the following order:
        1. path ending with ".xml" - environment built from the XML model;
        2. path ending with ".mjb" - environment built from the MJB model;
        3. existing ".py" file - the file is executed and its `make_env`
           is called;
        4. existing ".jsonnet" file - the file is evaluated and the function
           described by its 'make_env' entry is called;
        5. otherwise python files under `envs_dir` and xml files under
           `xmls_dir` are searched for paths containing `pattern`
           (files whose name starts with 'test_' are skipped).

    :param pattern: tries to match environment to the pattern.
    :param core_dir: Absolute path to the core code directory for the project
        containing the environments we want to examine. This is usually the
        top-level git repository folder - in the case of the mujoco-worldgen
        repo, it would be the 'mujoco-worldgen' folder.
    :param envs_dir: relative path (from core_dir) to folder containing all
        environment files.
    :param xmls_dir: relative path (from core_dir) to folder containing all
        xml files.
    :param return_args_remaining: if True, additionally returns the arguments
        from kwargs that were not used.
    :param kwargs: arguments passed to the environment function.
    :return: mujoco_worldgen.Env, or a tuple (env, args_remaining) when
        `return_args_remaining` is True. The environment is None when nothing
        matches the pattern.
    :raises AssertionError: when the directory search finds more than one
        matching file.
    """
    env = None
    args_remaining = {}
    if pattern.endswith(".xml"):
        # Loads environment based on XML.
        if kwargs:
            print("Not passing any argument to environment, "
                  "because environment is loaded from XML. XML doesn't "
                  "accept any extra input arguments")

        def get_sim(seed):
            model = load_model_from_path_fix_paths(xml_path=pattern)
            return MjSim(model)

        env = Env(get_sim=get_sim)
    elif pattern.endswith(".mjb"):
        # Loads environment based on mjb.
        if kwargs:
            print("Not passing any argument to environment, "
                  "because environment is loaded from MJB. MJB doesn't "
                  "accept any extra input arguments")

        def get_sim(seed):
            model = load_model_from_mjb(pattern)
            return MjSim(model)

        env = Env(get_sim=get_sim)
    elif pattern.endswith(".py") and os.path.exists(pattern):
        # Loads environment from a python file. The suffix includes the dot,
        # like every other branch, so that only real python files are run.
        print("Loading env from the module: %s" % pattern)
        module = run_path(pattern)
        make_env = module["make_env"]
        args_to_pass, args_remaining = extract_matching_arguments(make_env, kwargs)
        env = make_env(**args_to_pass)
    elif pattern.endswith(".jsonnet") and os.path.exists(pattern):
        # Loads environment from a function described in a jsonnet file.
        env_data = json.loads(_jsonnet.evaluate_file(pattern))
        make_env = get_function(env_data['make_env'])
        args_to_pass, args_remaining = extract_matching_arguments(make_env, kwargs)
        env = make_env(**args_to_pass)
    else:
        # If couldn't load based on easy search, then look
        # into predefined subdirectories.
        candidates = (glob(join(core_dir, envs_dir, "**", "*.py"), recursive=True) +
                      glob(join(core_dir, xmls_dir, "**", "*.xml"), recursive=True))
        matching = [path for path in candidates
                    if pattern in path
                    and not os.path.basename(path).startswith('test_')]
        assert len(matching) < 2, "Found multiple environments matching %s" % str(matching)
        if len(matching) == 1:
            # The match is a concrete .py / .xml path, so the recursive call
            # is handled by one of the branches above.
            return load_env(matching[0], return_args_remaining=return_args_remaining,
                            **kwargs)
        # No match: fall through and return env=None.
    if return_args_remaining:
        return env, args_remaining
    return env


def load_model_from_path_fix_paths(xml_path, zero_gravity=True):
    """
    Loads model from XML path.

    The file is parsed with parse_file() (validation disabled), optionally
    modified, serialized back with unparse_dict() and loaded from the
    resulting XML string. Any fixing of asset paths is presumably done by
    parse_file() / unparse_dict() - it is not performed here.

    :param xml_path: path to xml file
    :param zero_gravity: if true, zero gravity in model
    :return: model returned by load_model_from_xml()
    """
    xml_dict = parse_file(xml_path, enforce_validation=False)
    if zero_gravity:
        # zero gravity so that the object doesn't fall down
        option = xml_dict.setdefault('option', OrderedDict())
        option['@gravity'] = np.zeros(3)
    xml = unparse_dict(xml_dict)
    return load_model_from_xml(xml)