in tf_agents/utils/nest_utils.py
def is_batched_nested_tensors(tensors,
                              specs,
                              num_outer_dims=1,
                              allow_extra_fields=False,
                              check_dtypes=True):
"""Compares tensors to specs to determine if all tensors are batched or not.
For each tensor, it checks the dimensions and dtypes with respect to specs.
Returns `True` if all tensors are batched and `False` if all tensors are
unbatched.
Raises a `ValueError` if the shapes are incompatible or a mix of batched and
unbatched tensors are provided.
Raises a `TypeError` if tensors' dtypes do not match specs.
Args:
tensors: Nested list/tuple/dict of Tensors.
specs: Nested list/tuple/dict of Tensors or CompositeTensors describing the
shape of unbatched tensors.
num_outer_dims: The integer number of dimensions that are considered batch
dimensions. Default 1.
allow_extra_fields: If `True`, then `tensors` may have extra subfields which
are not in specs. In this case, the extra subfields
will not be checked. For example: ```python
tensors = {"a": tf.zeros((3, 4), dtype=tf.float32),
"b": tf.zeros((5, 6), dtype=tf.float32)}
specs = {"a": tf.TensorSpec(shape=(4,), dtype=tf.float32)} assert
is_batched_nested_tensors(tensors, specs, allow_extra_fields=True) ```
The above example would raise a ValueError if `allow_extra_fields` was
False.
check_dtypes: If `True` will validate that tensors and specs have the same
dtypes.
Returns:
True if all Tensors are batched and False if all Tensors are unbatched.
Raises:
ValueError: If
1. Any of the tensors or specs have shapes with ndims == None, or
2. The shape of Tensors are not compatible with specs, or
3. A mix of batched and unbatched tensors are provided.
4. The tensors are batched but have an incorrect number of outer dims.
TypeError: If `dtypes` between tensors and specs are not compatible.
"""
  if allow_extra_fields:
    tensors = prune_extra_keys(specs, tensors)
  assert_same_structure(
      tensors,
      specs,
      message='Tensors and specs do not have matching structures')
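
  # Flatten both nests. For structurally identical nests, tf.nest.flatten
  # traverses them in the same deterministic order, so the flattened tensors
  # and specs line up element-wise for the shape/dtype comparisons below.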
  flat_tensors = tf.nest.flatten(tensors)
  flat_specs = tf.nest.flatten(specs)

  tensor_shapes = [t.shape for t in flat_tensors]
  tensor_dtypes = [t.dtype for t in flat_tensors]
  spec_shapes = [spec_shape(s) for s in flat_specs]
  spec_dtypes = [s.dtype for s in flat_specs]
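
  # The outer-dimension comparisons below require fully known ranks, so reject
  # any spec or tensor whose rank is undefined.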
  if any(s_shape.rank is None for s_shape in spec_shapes):
    raise ValueError('All specs should have ndims defined. Saw shapes: %s' %
                     (tf.nest.pack_sequence_as(specs, spec_shapes),))

  if any(t_shape.rank is None for t_shape in tensor_shapes):
    raise ValueError('All tensors should have ndims defined. Saw shapes: %s' %
                     (tf.nest.pack_sequence_as(specs, tensor_shapes),))

  if (check_dtypes and
      any(s_dtype != t_dtype
          for s_dtype, t_dtype in zip(spec_dtypes, tensor_dtypes))):
    raise TypeError('Tensor dtypes do not match spec dtypes:\n{}\nvs.\n{}'
                    .format(tf.nest.pack_sequence_as(specs, tensor_dtypes),
                            tf.nest.pack_sequence_as(specs, spec_dtypes)))
  is_unbatched = [
      s_shape.is_compatible_with(t_shape)
      for s_shape, t_shape in zip(spec_shapes, tensor_shapes)
  ]

  if all(is_unbatched):
    return False
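
  # Otherwise, count how many extra leading (outer) dimensions each tensor has
  # relative to its spec, and check whether the remaining trailing dimensions
  # are still compatible with the spec.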
  tensor_ndims_discrepancy = [
      t_shape.rank - s_shape.rank
      for s_shape, t_shape in zip(spec_shapes, tensor_shapes)
  ]

  tensor_matches_spec = [
      s_shape.is_compatible_with(t_shape[discrepancy:])
      for discrepancy, s_shape, t_shape in zip(
          tensor_ndims_discrepancy, spec_shapes, tensor_shapes)
  ]

  # Check if all tensors match and have correct number of outer_dims.
  is_batched = (
      all(discrepancy == num_outer_dims
          for discrepancy in tensor_ndims_discrepancy) and
      all(tensor_matches_spec))

  if is_batched:
    return True

  # Check if tensors match but have incorrect number of batch dimensions.
  if all(
      discrepancy == tensor_ndims_discrepancy[0]
      for discrepancy in tensor_ndims_discrepancy) and all(tensor_matches_spec):
    return False

  raise ValueError(
      'Received a mix of batched and unbatched Tensors, or Tensors'
      ' are not compatible with Specs. num_outer_dims: %d.\n'
      'Saw tensor_shapes:\n %s\n'
      'And spec_shapes:\n %s' %
      (num_outer_dims, tf.nest.pack_sequence_as(specs, tensor_shapes),
       tf.nest.pack_sequence_as(specs, spec_shapes)))
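

# A minimal usage sketch (not part of nest_utils.py itself), assuming the
# module is imported as `from tf_agents.utils import nest_utils`: with the
# default num_outer_dims=1, a nest whose tensors each carry one extra leading
# dimension relative to their specs is reported as batched, while a nest whose
# shapes already match the specs is reported as unbatched.
#
#   import tensorflow as tf
#   from tf_agents.utils import nest_utils
#
#   specs = {'obs': tf.TensorSpec(shape=(4,), dtype=tf.float32)}
#   batched = {'obs': tf.zeros((32, 4), dtype=tf.float32)}
#   unbatched = {'obs': tf.zeros((4,), dtype=tf.float32)}
#
#   assert nest_utils.is_batched_nested_tensors(batched, specs)        # True
#   assert not nest_utils.is_batched_nested_tensors(unbatched, specs)  # False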