in tensorflow_graphics/util/shape.py [0:0]
def compare_batch_dimensions(
    tensors: Union[List[tf.Tensor], Tuple[tf.Tensor]],
    last_axes: Union[int, List[int], Tuple[int]],
    broadcast_compatible: bool,
    initial_axes: Union[int, List[int], Tuple[int]] = 0,
    tensor_names: Optional[Union[List[str], Tuple[str]]] = None) -> None:
  """Checks that the batch dimensions of statically shaped tensors agree.

  For every tensor, the batch is the slice of its static shape running from
  its initial axis to its last axis, both inclusive. When
  `broadcast_compatible` is False, the batch slices must have the same rank,
  and every position must hold either one common value or a mix of `None`
  and a single known size. When it is True, every pair of batch slices must
  be broadcast compatible.

  Args:
    tensors: A list or tuple of tensors with static shapes to compare.
    last_axes: An `int`, or a list or tuple of `int`s with one entry per
      tensor. A single `int` is applied to every tensor. Each entry is the
      zero-based index of the last batch axis; with a single batch dimension
      this is `0`. Negative values are accepted and normalized by `_fix_axes`.
    broadcast_compatible: A `bool`. If True, the batch shapes only need to be
      broadcast compatible in the numpy sense; otherwise they must match.
    initial_axes: An `int`, or a list or tuple of `int`s with one entry per
      tensor. A single `int` is applied to every tensor. Each entry is the
      zero-based index of the first batch axis. Defaults to `0`.
    tensor_names: Names of `tensors` to use in the error message if one is
      raised. If `None`, default names of the form `tensor_i` are used.

  Raises:
    ValueError: If the inputs have unexpected types, if an axis is out of
      bounds, or if the batch dimensions do not agree.
  """
  _check_tensors(tensors, 'tensors')

  # A scalar axis is broadcast to every tensor.
  initial_axes = ([initial_axes] * len(tensors)
                  if isinstance(initial_axes, int) else initial_axes)
  last_axes = ([last_axes] * len(tensors)
               if isinstance(last_axes, int) else last_axes)

  for axis_list, axis_list_name in ((initial_axes, 'initial_axes'),
                                    (last_axes, 'last_axes')):
    _check_tensor_axis_lists(tensors, 'tensors', axis_list, axis_list_name)

  starts = _fix_axes(tensors, initial_axes, allow_negative=True)
  stops = _fix_axes(tensors, last_axes, allow_negative=True)

  # `stop` is inclusive, hence the `+ 1` in the slice bound.
  batch_shapes = [
      tensor.shape[start:stop + 1]
      for tensor, start, stop in zip(tensors, starts, stops)
  ]
  names = (_give_default_names(tensors, 'tensor')
           if tensor_names is None else tensor_names)

  if broadcast_compatible:
    shape_pairs = itertools.combinations(batch_shapes, 2)
    if not all(is_broadcast_compatible(first, second)
               for first, second in shape_pairs):
      described = [(name, shape.as_list())
                   for name, shape in zip(names, batch_shapes)]
      raise ValueError(
          'Not all batch dimensions are broadcast-compatible: {}'.format(
              described))
    return

  ranks = [shape.ndims for shape in batch_shapes]
  shape_lists = [shape.as_list() for shape in batch_shapes]
  if not _all_are_equal(ranks):
    # Batch shapes of differing rank can never be identical.
    _raise_error(names, shape_lists)

  for column in zip(*shape_lists):
    if _all_are_equal(column):
      # All entries are None, or all share the same known size.
      continue
    # A non-uniform column is only acceptable when it mixes None with exactly
    # one known size, i.e. set(column) == {None, size}.
    if None not in column or len(set(column)) != 2:
      _raise_error(names, shape_lists)