def compare_batch_dimensions()

in tensorflow_graphics/util/shape.py
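
This excerpt relies on the module's imports and on helpers defined elsewhere
in shape.py (`_check_tensors`, `_check_tensor_axis_lists`, `_fix_axes`,
`_give_default_names`, `_all_are_equal`, `_raise_error`, and
`is_broadcast_compatible`). The imports below are the ones the signature and
body use:

import itertools
from typing import List, Optional, Tuple, Union

import tensorflow as tf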


def compare_batch_dimensions(
    tensors: Union[List[tf.Tensor], Tuple[tf.Tensor, ...]],
    last_axes: Union[int, List[int], Tuple[int, ...]],
    broadcast_compatible: bool,
    initial_axes: Union[int, List[int], Tuple[int, ...]] = 0,
    tensor_names: Optional[Union[List[str], Tuple[str, ...]]] = None) -> None:
  """Compares batch dimensions for tensors with static shapes.

  Args:
    tensors: A list or tuple of tensors with static shapes to compare.
    last_axes: An `int` or a list or tuple of `int`s with the same length as
      `tensors`. If an `int`, it is assumed to be the same for all the tensors.
      Each entry should correspond to the last axis of the batch (with
      zero-based indices). For instance, if there is only a single batch
      dimension, the last axis should be `0`.
    broadcast_compatible: A `bool`; if `True`, the batch shapes only need to
      be broadcast-compatible in the NumPy sense rather than identical.
    initial_axes: An `int` or a list or tuple of `int`s with the same length as
      `tensors`. If an `int`, it is assumed to be the same for all the tensors.
      Each entry should correspond to the first axis of the batch (with
      zero-based indices). Default value is `0`.
    tensor_names: Names of `tensors` to be used in the error message if one is
      thrown. If left as `None`, `tensor_i` is used, where `i` is the index of
      the tensor in `tensors`.

  Raises:
    ValueError: If inputs have unexpected types, or if given axes are out of
      bounds, or if the check fails.
  """
  _check_tensors(tensors, 'tensors')
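  # A single `int` for the axes arguments applies to every tensor; expand it
  # to a per-tensor list so the rest of the function can treat both forms
  # uniformly.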
  if isinstance(initial_axes, int):
    initial_axes = [initial_axes] * len(tensors)
  if isinstance(last_axes, int):
    last_axes = [last_axes] * len(tensors)
  _check_tensor_axis_lists(tensors, 'tensors', initial_axes, 'initial_axes')
  _check_tensor_axis_lists(tensors, 'tensors', last_axes, 'last_axes')
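  # Convert negative axis indices to their non-negative equivalents, relative
  # to each tensor's rank.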
  initial_axes = _fix_axes(tensors, initial_axes, allow_negative=True)
  last_axes = _fix_axes(tensors, last_axes, allow_negative=True)
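  # Slice out each tensor's batch shape; the slice is inclusive of the last
  # batch axis, hence the `last + 1`.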
  batch_shapes = [
      tensor.shape[init:last + 1]
      for tensor, init, last in zip(tensors, initial_axes, last_axes)
  ]
  if tensor_names is None:
    tensor_names = _give_default_names(tensors, 'tensor')
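  # Without broadcasting, batch shapes must be identical; an unknown (None)
  # dimension is still accepted against a single known value.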
  if not broadcast_compatible:
    batch_ndims = [batch_shape.ndims for batch_shape in batch_shapes]
    batch_shapes = [batch_shape.as_list() for batch_shape in batch_shapes]
    if not _all_are_equal(batch_ndims):
      # If not all batch shapes have the same length, they cannot be identical.
      _raise_error(tensor_names, batch_shapes)
    for dims in zip(*batch_shapes):
      if _all_are_equal(dims):
        # Continue if all dimensions are None or have the same value.
        continue
      if None not in dims:
        # If all dimensions are known at this point, they are not identical.
        _raise_error(tensor_names, batch_shapes)
      # At this point, dims contains both None and at least one int.
      if len(set(dims)) != 2:
        # If the dimensions agreed up to unknowns, set(dims) would be
        # {None, some_int}; more than two distinct values means at least two
        # different known sizes, so the shapes cannot be identical.
        _raise_error(tensor_names, batch_shapes)
  else:
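    # With broadcasting allowed, it suffices that every pair of batch shapes
    # is broadcast-compatible.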
    if not all(
        is_broadcast_compatible(shape1, shape2)
        for shape1, shape2 in itertools.combinations(batch_shapes, 2)):
      raise ValueError(
          'Not all batch dimensions are broadcast-compatible: {}'.format([
              (name, batch_shape.as_list())
              for name, batch_shape in zip(tensor_names, batch_shapes)
          ]))
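
A minimal usage sketch (the tensor shapes below are illustrative; it assumes
the function is imported via `from tensorflow_graphics.util import shape`):

import tensorflow as tf

from tensorflow_graphics.util import shape

points = tf.zeros((8, 100, 3))       # batch axis 0, 100 points, xyz.
normals = tf.zeros((8, 100, 3))      # same batch shape as `points`.
matrices = tf.zeros((1, 100, 3, 3))  # batch of 1, broadcastable against 8.

# Exact match: both tensors have the single batch dimension 8.
shape.compare_batch_dimensions(
    tensors=(points, normals), last_axes=0, broadcast_compatible=False)

# Broadcast-compatible check: a batch of 1 is compatible with a batch of 8.
shape.compare_batch_dimensions(
    tensors=(points, matrices), last_axes=0, broadcast_compatible=True)

# Mismatched batch dimensions raise ValueError:
# shape.compare_batch_dimensions(
#     tensors=(points, tf.zeros((4, 100, 3))),
#     last_axes=0,
#     broadcast_compatible=False)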