def assert_matching_dtypes_and_inner_shapes()

in tf_agents/utils/nest_utils.py [0:0]


def assert_matching_dtypes_and_inner_shapes(tensors_or_specs,
                                            specs,
                                            caller,
                                            tensors_name,
                                            specs_name,
                                            allow_extra_fields=False):
  """Asserts that tensors and specs have matching dtypes and inner shapes.

  The "inner" shape of a tensor is its trailing `spec.shape.ndims` dimensions;
  any leading (outer) dimensions are ignored. A spec of rank 0 or unknown rank,
  or a tensor of unknown rank, is treated as shape-compatible.

  Leaves of `tensors_or_specs` that are neither `tf.TypeSpec` objects nor
  tensors are first converted with `tf.convert_to_tensor`, using the matching
  spec's dtype as a `dtype_hint`.

  Args:
    tensors_or_specs: A nest of `Tensor` like or `tf.TypeSpec` objects.
    specs: A nest of `tf.TypeSpec` objects.
    caller: The object calling `assert...`.
    tensors_name: (str) Name to use for the tensors in case of an error.
    specs_name: (str) Name to use for the specs in case of an error.
    allow_extra_fields: If `True`, then `tensors` may contain more keys or list
      fields than strictly required by `specs`.

  Raises:
    ValueError: If the tensors do not match the specs' dtypes or their inner
      shapes do not match the specs' shapes. A structural mismatch is reported
      by `assert_same_structure` before any dtype/shape check runs.
  """
  if allow_extra_fields:
    # Drop entries `specs` does not mention so the structures can line up.
    tensors_or_specs = prune_extra_keys(specs, tensors_or_specs)
  assert_same_structure(
      tensors_or_specs,
      specs,
      message=('{}: {} and {} do not have matching structures'.format(
          caller, tensors_name, specs_name)))

  flat_specs = tf.nest.flatten(specs)

  def _convert(value, spec):
    # Tensors and specs are used as-is; anything else (Python scalars, numpy
    # arrays, ...) becomes a tensor, preferring the spec's dtype if possible.
    if isinstance(value, tf.TypeSpec) or tf.is_tensor(value):
      return value
    return tf.convert_to_tensor(value, dtype_hint=spec.dtype)

  flat_tensors = [
      _convert(t, s)
      for t, s in zip(tf.nest.flatten(tensors_or_specs), flat_specs)
  ]

  def _inner_shape_compatible(s_shape, t_shape):
    # Scalar / unknown-rank specs and unknown-rank tensors always match.
    if s_shape.ndims in (0, None) or t_shape.ndims is None:
      return True
    # The spec cannot describe more trailing dims than the tensor has.
    if s_shape.ndims > t_shape.ndims:
      return False
    return s_shape.is_compatible_with(t_shape[-s_shape.ndims:])

  tensor_dtypes = [t.dtype for t in flat_tensors]
  spec_dtypes = [s.dtype for s in flat_specs]
  tensor_shapes = [t.shape for t in flat_tensors]
  spec_shapes = [spec_shape(s) for s in flat_specs]

  compatible = (
      all(s_dtype == t_dtype
          for s_dtype, t_dtype in zip(spec_dtypes, tensor_dtypes)) and
      all(_inner_shape_compatible(s_shape, t_shape)
          for s_shape, t_shape in zip(spec_shapes, tensor_shapes)))

  if not compatible:
    # Describe the *converted* leaves: raw Python scalars have no `.dtype` or
    # `.shape`, so formatting the original nest would raise AttributeError and
    # mask the intended ValueError.
    converted = tf.nest.pack_sequence_as(tensors_or_specs, flat_tensors)
    get_dtypes = lambda v: tf.nest.map_structure(lambda x: x.dtype, v)
    get_shapes = lambda v: tf.nest.map_structure(spec_shape, v)
    raise ValueError('{}: Inconsistent dtypes or shapes between {} and {}.\n'
                     'dtypes:\n{}\nvs.\n{}.\n'
                     'shapes:\n{}\nvs.\n{}.'.format(
                         caller,
                         tensors_name,
                         specs_name,
                         get_dtypes(converted),
                         get_dtypes(specs),
                         get_shapes(converted),
                         get_shapes(specs)))