in mujoco_worldgen/util/envs/flexible_load.py [0:0]
def load_env(pattern, core_dir=worldgen_path(), envs_dir='examples', xmls_dir='xmls',
             return_args_remaining=False, **kwargs):
    """
    Flexible load of an environment based on `pattern`.

    Resolution order:
      1. `pattern` ends with ".xml": wrap an Env around a sim built from that XML file.
      2. `pattern` ends with ".mjb": wrap an Env around a sim built from that MJB file.
      3. `pattern` is an existing ".py" file: run it and call its `make_env`.
      4. `pattern` is an existing ".jsonnet" file: evaluate it and call the function
         named by its 'make_env' entry.
      5. Otherwise search `core_dir/envs_dir/**/*.py` and `core_dir/xmls_dir/**/*.xml`
         for paths containing `pattern` as a substring (skipping files whose basename
         starts with 'test_'), and load the single match recursively.

    :param pattern: file path, or substring used to find the environment file.
    :param core_dir: Absolute path to the core code directory for the project containing
        the environments we want to examine. This is usually the top-level git repository
        folder - in the case of the mujoco-worldgen repo, it would be the 'mujoco-worldgen'
        folder.
    :param envs_dir: relative path (from core_dir) to folder containing all environment files.
    :param xmls_dir: relative path (from core_dir) to folder containing all xml files.
    :param return_args_remaining: if True, return a tuple (env, args_remaining), where
        args_remaining holds the entries of kwargs that make_env did not accept. For
        XML/MJB loads (which take no arguments) and for failed lookups it is an empty dict.
    :param kwargs: arguments passed to the environment function (make_env).
    :return: the loaded environment, or None if nothing matched `pattern`; a tuple
        (env, args_remaining) when return_args_remaining is True.
    :raises AssertionError: if the directory search matches more than one file.
    """
    env = None
    args_remaining = {}
    if pattern.endswith(".xml"):
        # Loads environment based on XML.
        if kwargs:
            print("Not passing any argument to environment, "
                  "because environment is loaded from XML. XML doesn't "
                  "accept any extra input arguments")

        def get_sim(seed):
            model = load_model_from_path_fix_paths(xml_path=pattern)
            return MjSim(model)
        env = Env(get_sim=get_sim)
    elif pattern.endswith(".mjb"):
        # Loads environment based on mjb.
        if kwargs:
            print("Not passing any argument to environment, "
                  "because environment is loaded from MJB. MJB doesn't "
                  "accept any extra input arguments")

        def get_sim(seed):
            model = load_model_from_mjb(pattern)
            return MjSim(model)
        env = Env(get_sim=get_sim)
    elif pattern.endswith(".py") and os.path.exists(pattern):
        # Loads environment from a python file. Checking ".py" (not just "py") keeps
        # files such as "foo.npy" from being executed as Python modules.
        print("Loading env from the module: %s" % pattern)
        module = run_path(pattern)
        make_env = module["make_env"]
        args_to_pass, args_remaining = extract_matching_arguments(make_env, kwargs)
        env = make_env(**args_to_pass)
    elif pattern.endswith(".jsonnet") and os.path.exists(pattern):
        # Loads environment from a jsonnet spec naming the make_env function.
        env_data = json.loads(_jsonnet.evaluate_file(pattern))
        make_env = get_function(env_data['make_env'])
        args_to_pass, args_remaining = extract_matching_arguments(make_env, kwargs)
        env = make_env(**args_to_pass)
    else:
        # If couldn't load based on easy search, then look
        # into predefined subdirectories.
        candidates = (glob(join(core_dir, envs_dir, "**", "*.py"), recursive=True) +
                      glob(join(core_dir, xmls_dir, "**", "*.xml"), recursive=True))
        matching = [match for match in candidates
                    if pattern in match
                    and not os.path.basename(match).startswith('test_')]
        assert len(matching) < 2, "Found multiple environments matching %s" % str(matching)
        if len(matching) == 1:
            # The match is an existing .py/.xml path, so the recursive call resolves
            # it through one of the branches above.
            return load_env(matching[0], core_dir=core_dir, envs_dir=envs_dir,
                            xmls_dir=xmls_dir,
                            return_args_remaining=return_args_remaining, **kwargs)
    if return_args_remaining:
        return env, args_remaining
    return env