def check_static()

in tensorflow_graphics/util/shape.py [0:0]


def check_static(tensor: tf.Tensor,
                 has_rank: Optional[int] = None,
                 has_rank_greater_than: Optional[int] = None,
                 has_rank_less_than: Optional[int] = None,
                 has_dim_equals=None,
                 has_dim_greater_than=None,
                 has_dim_less_than=None,
                 tensor_name: str = 'tensor') -> None:
  """Checks static shapes for rank and dimension constraints.

  This function can be used to check a tensor's shape for multiple rank and
  dimension constraints at the same time. Checks run in the order the
  arguments are listed below, and the first failing check raises.

  Args:
    tensor: Any tensor with a static shape.
    has_rank: An int or `None`. If not `None`, the function checks if the rank
      of the `tensor` equals to `has_rank`.
    has_rank_greater_than: An int or `None`. If not `None`, the function checks
      if the rank of the `tensor` is greater than `has_rank_greater_than`.
    has_rank_less_than: An int or `None`. If not `None`, the function checks if
      the rank of the `tensor` is less than `has_rank_less_than`.
    has_dim_equals: Either a tuple or list containing a single pair of `int`s,
      or a list or tuple containing multiple such pairs. Each pair is in the
      form (`axis`, `dim`), which means the function should check if
      `tensor.shape[axis] == dim`.
    has_dim_greater_than: Either a tuple or list containing a single pair of
      `int`s, or a list or tuple containing multiple such pairs. Each pair is in
      the form (`axis`, `dim`), which means the function should check if
      `tensor.shape[axis] > dim`.
    has_dim_less_than: Either a tuple or list containing a single pair of
      `int`s, or a list or tuple containing multiple such pairs. Each pair is in
      the form (`axis`, `dim`), which means the function should check if
      `tensor.shape[axis] < dim`.
    tensor_name: A name for `tensor` to be used in the error message if one is
      thrown.

  Raises:
    ValueError: If any input is not of the expected types, or if one of the
      checks described above fails. A requested rank constraint on a tensor
      whose rank is statically unknown (`ndims is None`), and a greater-than or
      less-than dimension constraint on an axis whose size is reported as
      `None`, are treated as failed checks.
  """
  # TODO(cengizo): Typing for has_dim_equals, has_dim_greater(less)_than.
  rank = tensor.shape.ndims

  def _shape_description():
    """Returns the shape to print in error messages, safe for unknown rank."""
    # `as_list()` raises on a shape of unknown rank, which would hide the
    # intended error message behind an unrelated one.
    return tensor.shape.as_list() if rank is not None else '<unknown>'

  def _raise_value_error_for_rank(variable, error_msg):
    raise ValueError(
        '{} must have a rank {} {}, but it has rank {} and shape {}'.format(
            tensor_name, error_msg, variable, rank, _shape_description()))

  def _raise_value_error_for_dim(error_msg, axis, value):
    raise ValueError(
        '{} must have {} {} dimensions in axis {}, but it has shape {}'.format(
            tensor_name, error_msg, value, axis, _shape_description()))

  def _check_dims(pairs, arg_name, violates, error_msg):
    """Validates `pairs`, then raises for the first (axis, dim) that violates."""
    _check_type(pairs, arg_name, (list, tuple))
    for axis, value in _fix_axis_dim_pairs(pairs, arg_name):
      if violates(_get_dim(tensor, axis), value):
        _raise_value_error_for_dim(error_msg, axis, value)

  if has_rank is not None:
    _check_type(has_rank, 'has_rank', int)
    if rank != has_rank:
      _raise_value_error_for_rank(has_rank, 'of')
  if has_rank_greater_than is not None:
    _check_type(has_rank_greater_than, 'has_rank_greater_than', int)
    # An unknown rank (None) cannot be ordered against an int in Python 3;
    # report it as a failed check instead of letting `<=` raise TypeError.
    if rank is None or rank <= has_rank_greater_than:
      _raise_value_error_for_rank(has_rank_greater_than, 'greater than')
  if has_rank_less_than is not None:
    _check_type(has_rank_less_than, 'has_rank_less_than', int)
    if rank is None or rank >= has_rank_less_than:
      _raise_value_error_for_rank(has_rank_less_than, 'less than')
  if has_dim_equals is not None:
    _check_dims(has_dim_equals, 'has_dim_equals',
                lambda dim, value: dim != value, 'exactly')
  # NOTE(review): the `dim is None` guards assume `_get_dim` (defined
  # elsewhere) returns None for a statically unknown dimension; ordering None
  # against an int would otherwise raise TypeError rather than ValueError.
  if has_dim_greater_than is not None:
    _check_dims(has_dim_greater_than, 'has_dim_greater_than',
                lambda dim, value: dim is None or not dim > value,
                'greater than')
  if has_dim_less_than is not None:
    _check_dims(has_dim_less_than, 'has_dim_less_than',
                lambda dim, value: dim is None or not dim < value,
                'less than')