mujoco_py/mjsimstate.pyx (70 lines of code) (raw):
class MjSimState(namedtuple('SimStateBase', 'time qpos qvel act udd_state')):
    """Represents a snapshot of the simulator's state.

    This includes time, qpos, qvel, act, and udd_state.

    ``udd_state`` is a dict whose values are either numbers or numpy arrays.
    ``act`` may be ``None``: ``flatten`` treats that as an empty array and
    ``from_flattened`` produces ``None`` when ``sim.model.na == 0``.
    """

    __slots__ = ()

    # Needs a custom implementation: the tuple default would compare the
    # fields with ==, which is elementwise (not a single bool) on numpy arrays.
    def __eq__(self, other):
        """Return True if every field of both states holds equal values.

        Returns ``NotImplemented`` when ``other`` is not an instance of this
        class, so Python can fall back to its default comparison handling.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented

        if set(self.udd_state.keys()) != set(other.udd_state.keys()):
            return False

        for k, v in self.udd_state.items():
            # np.array_equal accepts scalars as well as arrays and always
            # yields a single truth value. Comparing with != first would
            # raise ValueError when one side is a scalar and the other a
            # multi-element array.
            if not np.array_equal(v, other.udd_state[k]):
                return False

        return (self.time == other.time and
                np.array_equal(self.qpos, other.qpos) and
                np.array_equal(self.qvel, other.qvel) and
                np.array_equal(self.act, other.act))

    def __ne__(self, other):
        """Return the negation of ``__eq__``, propagating ``NotImplemented``."""
        result = self.__eq__(other)
        # `not NotImplemented` would evaluate to False, i.e. claim that an
        # unrelated object is "not unequal" to this state.
        if result is NotImplemented:
            return NotImplemented
        return not result

    def flatten(self):
        """ Flattens a state into a numpy array of numbers.

        Layout: ``[time, *qpos, *qvel, *act, *udd_state values]`` where the
        udd_state values are ordered by sorted key.
        """
        # A missing actuator state contributes no elements.
        act = np.empty(0) if self.act is None else self.act
        state_tuple = ([self.time], self.qpos, self.qvel, act,
                       MjSimState._flatten_dict(self.udd_state))
        return np.concatenate(state_tuple)

    @staticmethod
    def _flatten_dict(d):
        """Flatten the values of ``d`` into a 1-D array, in sorted-key order.

        Number values contribute one element; any other value is expected to
        provide ``ravel()`` and contributes all of its elements.
        """
        a = []
        for k in sorted(d.keys()):
            v = d[k]
            if isinstance(v, Number):
                a.append(v)
            else:
                a.extend(v.ravel())
        return np.array(a)

    @staticmethod
    def from_flattened(array, sim):
        """Rebuild a state from the layout produced by ``flatten``.

        Segment lengths are read from ``sim.model.nq`` / ``nv`` / ``na``, and
        ``sim.udd_state`` serves as the schema (keys, scalar vs. array, array
        shapes) for the trailing udd_state segment.

        qpos, qvel and act are obtained by slicing ``array``; with a numpy
        array input they are therefore views rather than copies.
        """
        nq, nv, na = sim.model.nq, sim.model.nv, sim.model.na

        idx_time = 0
        idx_qpos = idx_time + 1
        idx_qvel = idx_qpos + nq
        idx_act = idx_qvel + nv
        idx_udd = idx_act + na

        time = array[idx_time]
        qpos = array[idx_qpos:idx_qpos + nq]
        qvel = array[idx_qvel:idx_qvel + nv]
        if na == 0:
            act = None
        else:
            act = array[idx_act:idx_act + na]
        flat_udd_state = array[idx_udd:]
        udd_state = MjSimState._unflatten_dict(flat_udd_state, sim.udd_state)

        return MjSimState(time, qpos, qvel, act, udd_state)

    @staticmethod
    def _unflatten_dict(a, schema_example):
        """Inverse of ``_flatten_dict``.

        Walks the keys of ``schema_example`` in sorted order, consuming one
        element of ``a`` for each Number value and ``size`` elements (reshaped
        to the schema value's shape) for each numpy array value.
        """
        d = {}
        idx = 0
        for k in sorted(schema_example.keys()):
            schema_val = schema_example[k]
            if isinstance(schema_val, Number):
                val = a[idx]
                idx += 1
                d[k] = val
            else:
                assert isinstance(schema_val, np.ndarray)
                val_array = a[idx:idx + schema_val.size]
                idx += schema_val.size
                val = np.array(val_array).reshape(schema_val.shape)
                d[k] = val
        return d