in tf_agents/environments/wrappers.py [0:0]
def __init__(self, env: py_environment.PyEnvironment, flat_dtype=None):
  """Creates a FlattenActionWrapper.

  The wrapped environment's (possibly nested) action spec is flattened with
  `tf.nest.flatten` and merged into a single rank-1 spec named
  'FlattenedActionSpec'. Its length is the sum of the component lengths, and
  scalar components count as length 1. If every component is an
  `array_spec.BoundedArraySpec`, the flat spec is bounded too. Its bounds are
  the component bounds concatenated in flatten order.

  Args:
    env: Environment to wrap.
    flat_dtype: Optional, if set to a np.dtype the flat action_spec uses this
      dtype. When `None`, all action specs must share one dtype, which is then
      used for the flat spec.

  Raises:
    ValueError: If the action_spec contains no specs.
    ValueError: If any of the action_spec shapes ndim > 1.
    ValueError: If dtypes differ across action specs and flat_dtype is not
      set.
  """
  super(FlattenActionWrapper, self).__init__(env)
  self._original_action_spec = env.action_spec()
  flat_action_spec = tf.nest.flatten(self._original_action_spec)
  if not flat_action_spec:
    # Without this check an empty spec only fails later, inside
    # np.hstack([]), with a much less helpful ValueError.
    raise ValueError('The action_spec must contain at least one spec.')
  if any(len(s.shape) > 1 for s in flat_action_spec):
    raise ValueError('ActionSpec shapes should all have ndim == 1.')
  if flat_dtype is None and any(
      s.dtype != flat_action_spec[0].dtype for s in flat_action_spec):
    raise ValueError(
        'All action_spec dtypes must match, or `flat_dtype` should be set.')
  # Compare against None instead of using `flat_dtype or ...`: a
  # non-structured np.dtype has len() == 0 and is therefore falsy, which would
  # silently discard an explicitly requested np.dtype.
  dtype = flat_action_spec[0].dtype if flat_dtype is None else flat_dtype
  # Scalar specs (shape ()) contribute one element each.
  # NOTE(review): a zero-length component (shape (0,)) is also counted as 1
  # here; confirm this matches the unflattening logic elsewhere in the class.
  shape = (sum((s.shape and s.shape[0]) or 1 for s in flat_action_spec),)
  if all(isinstance(s, array_spec.BoundedArraySpec) for s in flat_action_spec):
    # Bounds may be scalars, so broadcast them to each component's shape
    # first. np.hstack promotes 0-d (scalar-shaped) results to length 1, which
    # keeps the concatenated bounds aligned with `shape`.
    minimum = np.hstack(
        [np.broadcast_to(s.minimum, shape=s.shape) for s in flat_action_spec])
    maximum = np.hstack(
        [np.broadcast_to(s.maximum, shape=s.shape) for s in flat_action_spec])
    self._action_spec = array_spec.BoundedArraySpec(
        shape=shape,
        dtype=dtype,
        minimum=minimum,
        maximum=maximum,
        name='FlattenedActionSpec')
  else:
    self._action_spec = array_spec.ArraySpec(
        shape=shape, dtype=dtype, name='FlattenedActionSpec')
  self._flat_action_spec = flat_action_spec