def get_outer_rank()

in tf_agents/utils/nest_utils.py [0:0]


def get_outer_rank(tensors, specs):
  """Compares tensors to specs to determine the number of batch dimensions.

  Each tensor's shape is compared against the shape of its corresponding
  spec. If every tensor is compatible with its spec as-is, the tensors are
  treated as unbatched. Otherwise every tensor must carry the same number of
  extra leading dimensions, and the remaining trailing dimensions must be
  compatible with the spec.

  Args:
    tensors: Nested list/tuple/dict of Tensors or SparseTensors.
    specs: Nested list/tuple/dict of TensorSpecs, describing the shape of
      unbatched tensors.

  Returns:
    The number of outer dimensions for all Tensors (zero if all are
      unbatched or empty).

  Raises:
    ValueError: If
      1. Any of the tensors or specs have shapes with ndims == None, or
      2. The shape of Tensors are not compatible with specs, or
      3. A mix of batched and unbatched tensors are provided, or
      4. The tensors do not all have the same number of outer dims.
    A structure mismatch between `tensors` and `specs` is reported by
    `assert_same_structure`.
  """
  assert_same_structure(
      tensors,
      specs,
      message='Tensors and specs do not have matching structures')
  tensor_shapes = [t.shape for t in tf.nest.flatten(tensors)]
  spec_shapes = [spec_shape(s) for s in tf.nest.flatten(specs)]

  if any(s_shape.rank is None for s_shape in spec_shapes):
    raise ValueError('All specs should have ndims defined.  Saw shapes: %s' %
                     spec_shapes)

  if any(t_shape.rank is None for t_shape in tensor_shapes):
    raise ValueError('All tensors should have ndims defined.  Saw shapes: %s' %
                     tensor_shapes)

  # Every tensor already matches its spec (this also covers an empty nest,
  # because all() of an empty iterable is True): no batch dimensions.
  if all(s_shape.is_compatible_with(t_shape)
         for s_shape, t_shape in zip(spec_shapes, tensor_shapes)):
    return 0

  # Number of extra leading dimensions each tensor has relative to its spec.
  # The value is negative when a tensor has a lower rank than its spec.
  tensor_ndims_discrepancy = [
      t_shape.rank - s_shape.rank
      for s_shape, t_shape in zip(spec_shapes, tensor_shapes)
  ]

  # Compare each spec against its tensor's shape with the leading
  # `discrepancy` dimensions sliced off.
  tensor_matches_spec = [
      s_shape.is_compatible_with(t_shape[discrepancy:])
      for discrepancy, s_shape, t_shape in zip(
          tensor_ndims_discrepancy, spec_shapes, tensor_shapes)
  ]

  # At this point we are guaranteed to have at least one tensor/spec;
  # otherwise the unbatched check above would have returned 0.
  num_outer_dims = tensor_ndims_discrepancy[0]

  # Batched only if all tensors agree on the number of outer dimensions and
  # each one matches its spec once those dimensions are stripped.
  is_batched = (
      all(discrepancy == num_outer_dims
          for discrepancy in tensor_ndims_discrepancy) and
      all(tensor_matches_spec))

  if is_batched:
    return num_outer_dims

  # Any other combination (differing outer-dimension counts, or trailing
  # dimensions that do not match the spec) is an error.
  raise ValueError('Received a mix of batched and unbatched Tensors, or Tensors'
                   ' are not compatible with Specs.  num_outer_dims: %d.\n'
                   'Saw tensor_shapes:\n   %s\n'
                   'And spec_shapes:\n   %s' %
                   (num_outer_dims, tensor_shapes, spec_shapes))