def is_batched_nested_tensors()

in tf_agents/utils/nest_utils.py [0:0]


def is_batched_nested_tensors(tensors,
                              specs,
                              num_outer_dims=1,
                              allow_extra_fields=False,
                              check_dtypes=True):
  """Compares tensors to specs to determine if all tensors are batched or not.

  For each tensor, it checks the dimensions and (optionally) dtypes with
  respect to its corresponding spec.

  Returns `True` if every tensor has exactly `num_outer_dims` extra leading
  dimensions relative to its spec and its remaining dimensions are compatible
  with the spec.

  Returns `False` if every tensor's shape is compatible with its spec as-is
  (i.e. all tensors are unbatched); this includes an empty structure. Note that
  `False` is ALSO returned when every tensor has the same number of extra
  leading dimensions, that number differs from `num_outer_dims`, and the
  trailing dimensions are compatible with the specs. Callers that must tell
  this case apart from "unbatched" need to inspect the shapes themselves.

  Raises a `ValueError` if the shapes are incompatible, or if tensors have
  differing numbers of extra leading dimensions (e.g. a mix of batched and
  unbatched tensors).

  Raises a `TypeError` if `check_dtypes` is set and tensors' dtypes do not
  match specs.

  Args:
    tensors: Nested list/tuple/dict of Tensors.
    specs: Nested list/tuple/dict of Tensors or CompositeTensors describing the
      shape of unbatched tensors.
    num_outer_dims: The integer number of dimensions that are considered batch
      dimensions.  Default 1.
    allow_extra_fields: If `True`, then `tensors` may have extra subfields which
      are not in specs.  In this case, the extra subfields will not be checked.
      For example:

      ```python
      tensors = {"a": tf.zeros((3, 4), dtype=tf.float32),
                 "b": tf.zeros((5, 6), dtype=tf.float32)}
      specs = {"a": tf.TensorSpec(shape=(4,), dtype=tf.float32)}
      assert is_batched_nested_tensors(tensors, specs, allow_extra_fields=True)
      ```

      The above example would instead fail the structure check (via
      `assert_same_structure`) if `allow_extra_fields` was False.
    check_dtypes: If `True` will validate that tensors and specs have the same
      dtypes.

  Returns:
    True if all Tensors are batched with exactly `num_outer_dims` outer dims.
    False if all Tensors are unbatched, or if all Tensors share the same
    (incorrect) number of extra outer dims and otherwise match their specs.

  Raises:
    ValueError: If
      1. Any of the tensors or specs have shapes with ndims == None, or
      2. The shapes of the Tensors are not compatible with the specs, or
      3. The Tensors have differing numbers of extra outer dims (e.g. a mix
         of batched and unbatched tensors is provided).
    TypeError: If `check_dtypes` is `True` and the dtypes of tensors and specs
      do not match.
    Whatever `assert_same_structure` raises if `tensors` (after pruning, when
    `allow_extra_fields` is set) and `specs` have different structures.
  """
  if allow_extra_fields:
    # Drop any fields of `tensors` that have no counterpart in `specs` so the
    # structure check below only compares the fields that specs describe.
    tensors = prune_extra_keys(specs, tensors)

  assert_same_structure(
      tensors,
      specs,
      message='Tensors and specs do not have matching structures')
  flat_tensors = nest.flatten(tensors)
  flat_specs = tf.nest.flatten(specs)

  tensor_shapes = [t.shape for t in flat_tensors]
  tensor_dtypes = [t.dtype for t in flat_tensors]
  spec_shapes = [spec_shape(s) for s in flat_specs]
  spec_dtypes = [s.dtype for s in flat_specs]

  # Rank arithmetic below requires every rank to be known.
  if any(s_shape.rank is None for s_shape in spec_shapes):
    raise ValueError('All specs should have ndims defined.  Saw shapes: %s' %
                     (tf.nest.pack_sequence_as(specs, spec_shapes),))

  if any(t_shape.rank is None for t_shape in tensor_shapes):
    # The structures were asserted equal above, so `specs` can be used as the
    # packing template for the tensor shapes.
    raise ValueError('All tensors should have ndims defined.  Saw shapes: %s' %
                     (tf.nest.pack_sequence_as(specs, tensor_shapes),))

  if (check_dtypes and
      any(s_dtype != t_dtype
          for s_dtype, t_dtype in zip(spec_dtypes, tensor_dtypes))):
    raise TypeError('Tensor dtypes do not match spec dtypes:\n{}\nvs.\n{}'
                    .format(tf.nest.pack_sequence_as(specs, tensor_dtypes),
                            tf.nest.pack_sequence_as(specs, spec_dtypes)))

  # A tensor is "unbatched" when its full shape is compatible with its spec.
  is_unbatched = [
      s_shape.is_compatible_with(t_shape)
      for s_shape, t_shape in zip(spec_shapes, tensor_shapes)
  ]

  # Also covers an empty structure, since all([]) is True.
  if all(is_unbatched):
    return False

  # Number of extra leading dims each tensor has relative to its spec.
  tensor_ndims_discrepancy = [
      t_shape.rank - s_shape.rank
      for s_shape, t_shape in zip(spec_shapes, tensor_shapes)
  ]

  # Compare only the trailing dims (after the extra leading ones) to the spec.
  # A negative discrepancy yields a slice whose rank is smaller than the
  # spec's, so it is never compatible.
  tensor_matches_spec = [
      s_shape.is_compatible_with(t_shape[discrepancy:])
      for discrepancy, s_shape, t_shape in zip(
          tensor_ndims_discrepancy, spec_shapes, tensor_shapes)
  ]

  # Check if all tensors match and have correct number of outer_dims.
  is_batched = (
      all(discrepancy == num_outer_dims
          for discrepancy in tensor_ndims_discrepancy) and
      all(tensor_matches_spec))

  if is_batched:
    return True

  # Check if tensors match but have incorrect number of batch dimensions.
  # Every tensor has the same number of extra outer dims, but that number is
  # not `num_outer_dims`; this is reported as "not batched" (False) rather
  # than an error. The list is non-empty here because the empty case returned
  # above.
  if all(
      discrepancy == tensor_ndims_discrepancy[0]
      for discrepancy in tensor_ndims_discrepancy) and all(tensor_matches_spec):
    return False

  raise ValueError(
      'Received a mix of batched and unbatched Tensors, or Tensors'
      ' are not compatible with Specs.  num_outer_dims: %d.\n'
      'Saw tensor_shapes:\n   %s\n'
      'And spec_shapes:\n   %s' %
      (num_outer_dims, tf.nest.pack_sequence_as(specs, tensor_shapes),
       tf.nest.pack_sequence_as(specs, spec_shapes)))