in tf_agents/utils/nest_utils.py [0:0]
def assert_matching_dtypes_and_inner_shapes(tensors_or_specs,
                                            specs,
                                            caller,
                                            tensors_name,
                                            specs_name,
                                            allow_extra_fields=False):
  """Asserts that tensors and specs have matching dtypes and inner shapes.

  "Inner" shape means the trailing dimensions of each tensor: a leaf matches
  its spec when the spec's shape is compatible with the last
  `spec_shape.ndims` dimensions of the tensor's shape, so any leading (outer)
  dimensions of the tensor are ignored. Leaves whose spec shape is a scalar or
  has unknown rank, or whose tensor shape has unknown rank, are only checked
  for dtype.

  Leaves of `tensors_or_specs` that are neither `tf.TypeSpec`s nor tensors
  (e.g. Python numbers or numpy arrays) are converted with
  `tf.convert_to_tensor`, using the matching spec's dtype only as a hint.

  Args:
    tensors_or_specs: A nest of `Tensor` like or `tf.TypeSpec` objects.
    specs: A nest of `tf.TypeSpec` objects.
    caller: The object calling `assert...`; included in error messages.
    tensors_name: (str) Name to use for the tensors in case of an error.
    specs_name: (str) Name to use for the specs in case of an error.
    allow_extra_fields: If `True`, then `tensors` may contain more keys or list
      fields than strictly required by `specs`; the extra fields are pruned
      with `prune_extra_keys` before any checking.

  Returns:
    None. The function only validates its inputs and signals failure by
    raising.

  Raises:
    ValueError: If the tensors do not match the specs' dtypes or their inner
      shapes do not match the specs' shapes. Structure mismatches are reported
      by `assert_same_structure`, whose error propagates unchanged.
  """
  if allow_extra_fields:
    tensors_or_specs = prune_extra_keys(specs, tensors_or_specs)
  assert_same_structure(
      tensors_or_specs,
      specs,
      message=('{}: {} and {} do not have matching structures'.format(
          caller, tensors_name, specs_name)))

  # The structures match at this point, so the flattened leaves of both nests
  # line up one-to-one.
  flat_specs = tf.nest.flatten(specs)
  flat_tensors = tf.nest.flatten(tensors_or_specs)

  def _convert(t, s):
    # Give non-tensor, non-spec leaves a `.dtype` / `.shape`. `dtype_hint` is
    # only a preference, so a value that cannot take the spec's dtype keeps its
    # own dtype and the mismatch is reported below rather than silently cast.
    if not isinstance(t, tf.TypeSpec) and not tf.is_tensor(t):
      t = tf.convert_to_tensor(t, dtype_hint=s.dtype)
    return t

  flat_tensors = [_convert(t, s) for t, s in zip(flat_tensors, flat_specs)]

  def _inner_shape_compatible(s_shape, t_shape):
    # Scalar or unknown-rank spec shapes and unknown-rank tensor shapes impose
    # no shape constraint. Skipping rank 0 also avoids `t_shape[-0:]`, which
    # would slice the full shape instead of an empty suffix.
    if s_shape.ndims in (0, None) or t_shape.ndims is None:
      return True
    if s_shape.ndims > t_shape.ndims:
      return False
    return s_shape.is_compatible_with(t_shape[-s_shape.ndims:])

  compatible = (
      all(s.dtype == t.dtype for s, t in zip(flat_specs, flat_tensors)) and
      all(_inner_shape_compatible(spec_shape(s), t.shape)
          for s, t in zip(flat_specs, flat_tensors)))

  if not compatible:
    # Report the *converted* leaves, not the raw inputs: raw Python values
    # (e.g. floats) have no `.dtype`, and reading it here would raise an
    # AttributeError that hides the intended ValueError.
    converted = tf.nest.pack_sequence_as(tensors_or_specs, flat_tensors)
    get_dtypes = lambda v: tf.nest.map_structure(lambda x: x.dtype, v)
    get_shapes = lambda v: tf.nest.map_structure(spec_shape, v)
    raise ValueError('{}: Inconsistent dtypes or shapes between {} and {}.\n'
                     'dtypes:\n{}\nvs.\n{}.\n'
                     'shapes:\n{}\nvs.\n{}.'.format(
                         caller,
                         tensors_name,
                         specs_name,
                         get_dtypes(converted),
                         get_dtypes(specs),
                         get_shapes(converted),
                         get_shapes(specs)))