mujoco_worldgen/parser/normalize.py (187 lines of code) (raw):
from collections import OrderedDict
from mujoco_worldgen.util.types import accepts, returns
from mujoco_worldgen.parser.const import list_types, float_arg_types
import numpy as np
from decimal import Decimal, getcontext
import ast
import re
getcontext().prec = 10
'''
This methods are used internally by parser.py
Internal notes:
normalize() - in-place normalizes and converts an xml dictionary
see docstring for notes about what types are converted
stringify() - in-place de-normalizes, and converts all values to strings
see docstring for notes
normalize_*() - return normal forms of values (numbers, vectors, etc)
raise exception if input value cannot be converted
'''
@accepts(OrderedDict)
def normalize(xml_dict):
'''
The starting point is a dictionary of the form returned by xmltodict.
See that module's documentation here:
https://github.com/martinblech/xmltodict
Normalize a mujoco model XML:
- some nodes have OrderDict value (mostly top-lever such as worldbody)
some nodes have list values (mostly lower level). Check const.py
for more information.
- parameters ('@name', etc) never have list or OrderedDict values
- "true" and "false" are converted to bool()
- numbers are converted to floats
- vectors are converted to np.ndarray()
Note: stringify() is the opposite of this, and converts everything back
into strings in preparation for unparse_dict().
'''
# As a legacy, previously many of our XMLs had an unused model name.
# This removes it (as part of annotate) and can be phased out eventually.
if '@model' in xml_dict:
del xml_dict['@model']
for key, value in xml_dict.items():
if isinstance(value, OrderedDict):
# There is one exception.
# <default> symbol occurs twice.
# Once as OrderDict (top-level), once as list (lower-level).
if key == "default":
if "@class" in value:
xml_dict[key] = [value]
elif key in list_types:
xml_dict[key] = [value]
normalize(value)
continue
if isinstance(value, list):
for child in value:
normalize(child)
continue
if isinstance(value, str):
xml_dict[key] = normalize_value(value)
# sometimes data is stored as int when it's float.
# We make a conversion here.
if key in float_arg_types:
if isinstance(xml_dict[key], int):
xml_dict[key] = float(xml_dict[key])
elif isinstance(xml_dict[key], np.ndarray):
xml_dict[key] = xml_dict[key].astype(np.float64)
@accepts((int, float, np.float32, np.float64, np.int64))
@returns(str)
def num2str(num):
ret = "%g" % Decimal("%.6f" % num)
if ret == "-0":
return "0"
else:
return ret
@accepts((np.ndarray, tuple, list))
@returns(str)
def vec2str(vec):
return " ".join([num2str(v) for v in vec])
@returns(bool)
def is_normalizeable(normalize_function, value):
'''
Wraps a normalize_*() function, and returns True if value can be
normalized by normalize_function, otherwise returns False.
'''
try:
normalize_function(value)
return True
except:
return False
def normalize_numeric(value):
''' Normalize a numeric value into a float. '''
if isinstance(value, (float, int, np.float64, np.int64)):
return value
if isinstance(value, (str, bytes)):
f = float(value)
if f == int(f): # preferentially return integers if equal
return int(f)
return f
raise ValueError('Cannot convert {} to numeric'.format(value))
@accepts((np.ndarray, list, tuple, str))
def normalize_vector(value):
''' Normalize a vector value to a np.ndarray(). '''
if isinstance(value, np.ndarray):
return value
if (isinstance(value, (list, tuple)) and len(value) > 0 and
is_normalizeable(normalize_numeric, value[0])):
return np.array(value)
if isinstance(value, str):
# Split on spaces, filter empty, convert to numpy array
if "," in value or re.search("\[.*\]", value) is not None:
return np.array(ast.literal_eval(value))
else:
split = value.split()
return np.array([normalize_numeric(v) for v in split])
raise ValueError('Cannot convert {} to vector'.format(value))
def normalize_boolean(value):
''' Normalize a boolean value to a bool(). '''
if isinstance(value, bool):
return value
if isinstance(value, str):
if value.lower().strip() == 'true':
return True
if value.lower().strip() == 'false':
return False
raise ValueError('Cannot convert {} to boolean'.format(value))
def normalize_none(value):
''' Normalize a none string value to a None. '''
if isinstance(value, None.__class__):
return value
if isinstance(value, str):
if value.lower().strip() == 'none':
return None
raise ValueError('Cannot convert {} to None'.format(value))
def normalize_string(value):
''' Normalize a string value. '''
if isinstance(value, bytes):
value = value.decode()
if isinstance(value, str):
return value.strip()
raise ValueError('Cannot convert {} to string'.format(value))
def normalize_value(value):
''' Return the normalized version of a value by trying normalize_*(). '''
if value is None:
return None
for normalizer in (normalize_numeric,
normalize_vector,
normalize_none,
normalize_boolean,
normalize_string):
try:
return normalizer(value)
except:
continue
raise ValueError('Cannot normalize {}: {}'.format(type(value), value))
@accepts((OrderedDict, list))
def stringify(xml_dict):
'''
De-normalize xml dictionary (or list), converting all pythonic values (arrays, bools)
into strings that will be used in the final XML.
This is the opposite of normalize().
'''
if isinstance(xml_dict, OrderedDict):
enumeration = list(xml_dict.items())
elif isinstance(xml_dict, list):
enumeration = enumerate(xml_dict)
for key, value in enumeration:
# Handle a list of nodes to stringify
if isinstance(value, list):
if len(value) == 0:
del xml_dict[key]
else:
if sum([isinstance(v, (int, float, np.float32, np.int)) for v in value]) == len(value):
xml_dict[key] = vec2str(value)
else:
stringify(value)
elif isinstance(value, OrderedDict):
stringify(value)
elif isinstance(value, (np.ndarray, tuple)):
xml_dict[key] = vec2str(value)
elif isinstance(value, float):
xml_dict[key] = num2str(value) # format with fixed decimal places
elif isinstance(value, bool): # MUST COME BEFORE int() CHECK
xml_dict[key] = str(value).lower() # True -> 'true', etc.
elif isinstance(value, int): # isinstance(True, int) -> True. SAD!
xml_dict[key] = str(value) # Format without decimal places
elif isinstance(value, str):
pass # Value is already fine
elif value is None:
pass
else:
raise ValueError(
'Bad type for key {}: {}'.format(key, type(value)))