def __init__()

in tf_agents/environments/wrappers.py [0:0]


  def __init__(self, env: py_environment.PyEnvironment, flat_dtype=None):
    """Creates a FlattenActionWrapper.

    The wrapped environment's (possibly nested) action spec is flattened with
    `tf.nest.flatten` and its leaves are concatenated into a single rank-1 spec
    named 'FlattenedActionSpec'. If every leaf is a `BoundedArraySpec` the
    result is a `BoundedArraySpec` whose bounds are the concatenation of each
    leaf's bounds (broadcast to the leaf's shape); otherwise it is a plain
    `ArraySpec`.

    Args:
      env: Environment to wrap.
      flat_dtype: Optional, if set to a np.dtype the flat action_spec uses this
        dtype. If None, every leaf spec must share the same dtype, and that
        dtype is used.

    Raises:
      ValueError: If the action_spec contains no leaf specs.
      ValueError: If any of the action_spec shapes ndim > 1.
      ValueError: If dtypes differ across action specs and flat_dtype is not
        set.
    """
    super(FlattenActionWrapper, self).__init__(env)
    self._original_action_spec = env.action_spec()
    flat_action_spec = tf.nest.flatten(self._original_action_spec)

    # Fail early with a clear message; an empty spec would otherwise surface
    # as an opaque error from np.hstack below.
    if not flat_action_spec:
      raise ValueError('ActionSpec must contain at least one spec to flatten.')

    if any(len(s.shape) > 1 for s in flat_action_spec):
      raise ValueError('ActionSpec shapes should all have ndim == 1.')

    first_dtype = flat_action_spec[0].dtype
    if flat_dtype is None and any(
        s.dtype != first_dtype for s in flat_action_spec):
      raise ValueError(
          'All action_spec dtypes must match, or `flat_dtype` should be set.')

    # Compare against None explicitly: simple numpy dtype objects such as
    # np.dtype('float32') are falsy, so `flat_dtype or ...` would silently
    # discard a caller-supplied dtype.
    dtype = first_dtype if flat_dtype is None else flat_dtype

    # A scalar leaf (shape ()) contributes one element; a vector leaf (shape
    # (n,)) contributes n elements, including n == 0.
    shape = (sum(s.shape[0] if s.shape else 1 for s in flat_action_spec),)

    if all(
        isinstance(s, array_spec.BoundedArraySpec) for s in flat_action_spec):
      # Broadcast each leaf's (possibly scalar) bounds to that leaf's shape so
      # the concatenated bounds line up element-wise with the flat action.
      minimum = np.hstack(
          [np.broadcast_to(s.minimum, shape=s.shape) for s in flat_action_spec])
      maximum = np.hstack(
          [np.broadcast_to(s.maximum, shape=s.shape) for s in flat_action_spec])
      self._action_spec = array_spec.BoundedArraySpec(
          shape=shape,
          dtype=dtype,
          minimum=minimum,
          maximum=maximum,
          name='FlattenedActionSpec')
    else:
      self._action_spec = array_spec.ArraySpec(
          shape=shape, dtype=dtype, name='FlattenedActionSpec')

    self._flat_action_spec = flat_action_spec