in tensorflow_graphics/util/shape.py [0:0]
def check_static(tensor: tf.Tensor,
                 has_rank: Optional[int] = None,
                 has_rank_greater_than: Optional[int] = None,
                 has_rank_less_than: Optional[int] = None,
                 has_dim_equals=None,
                 has_dim_greater_than=None,
                 has_dim_less_than=None,
                 tensor_name: str = 'tensor') -> None:
  # TODO(cengizo): Typing for has_dim_equals, has_dim_greater(less)_than.
  """Checks static shapes for rank and dimension constraints.

  This function can be used to check a tensor's shape for multiple rank and
  dimension constraints at the same time.

  If the static rank of `tensor` is unknown, every requested rank check fails
  with a `ValueError`. The same applies when a requested dimension is
  statically unknown. Previously, some of these cases surfaced as a
  `TypeError` from comparing `None` with an `int`.

  Args:
    tensor: Any tensor with a static shape.
    has_rank: An int or `None`. If not `None`, the function checks if the rank
      of the `tensor` equals to `has_rank`.
    has_rank_greater_than: An int or `None`. If not `None`, the function checks
      if the rank of the `tensor` is greater than `has_rank_greater_than`.
    has_rank_less_than: An int or `None`. If not `None`, the function checks if
      the rank of the `tensor` is less than `has_rank_less_than`.
    has_dim_equals: Either a tuple or list containing a single pair of `int`s,
      or a list or tuple containing multiple such pairs. Each pair is in the
      form (`axis`, `dim`), which means the function should check if
      `tensor.shape[axis] == dim`.
    has_dim_greater_than: Either a tuple or list containing a single pair of
      `int`s, or a list or tuple containing multiple such pairs. Each pair is in
      the form (`axis`, `dim`), which means the function should check if
      `tensor.shape[axis] > dim`.
    has_dim_less_than: Either a tuple or list containing a single pair of
      `int`s, or a list or tuple containing multiple such pairs. Each pair is in
      the form (`axis`, `dim`), which means the function should check if
      `tensor.shape[axis] < dim`.
    tensor_name: A name for `tensor` to be used in the error message if one is
      thrown.

  Raises:
    ValueError: If any input is not of the expected types, or if one of the
      checks described above fails (including when the rank or a checked
      dimension is statically unknown).
  """
  rank = tensor.shape.ndims
  # `TensorShape.as_list()` raises its own ValueError on a shape of unknown
  # rank, which would hide the message below. Only call it when the rank is
  # known, so messages for known-rank shapes are unchanged.
  shape_for_message = (
      tensor.shape.as_list() if rank is not None else 'unknown')

  def _raise_value_error_for_rank(variable, error_msg):
    raise ValueError(
        '{} must have a rank {} {}, but it has rank {} and shape {}'.format(
            tensor_name, error_msg, variable, rank, shape_for_message))

  def _raise_value_error_for_dim(error_msg, axis, value):
    raise ValueError(
        '{} must have {} {} dimensions in axis {}, but it has shape {}'.format(
            tensor_name, error_msg, value, axis, shape_for_message))

  if has_rank is not None:
    _check_type(has_rank, 'has_rank', int)
    if rank is None or rank != has_rank:
      _raise_value_error_for_rank(has_rank, 'of')

  if has_rank_greater_than is not None:
    _check_type(has_rank_greater_than, 'has_rank_greater_than', int)
    # An unknown rank cannot be shown to satisfy the bound; comparing None
    # with an int would otherwise raise TypeError.
    if rank is None or rank <= has_rank_greater_than:
      _raise_value_error_for_rank(has_rank_greater_than, 'greater than')

  if has_rank_less_than is not None:
    _check_type(has_rank_less_than, 'has_rank_less_than', int)
    if rank is None or rank >= has_rank_less_than:
      _raise_value_error_for_rank(has_rank_less_than, 'less than')

  if has_dim_equals is not None:
    _check_type(has_dim_equals, 'has_dim_equals', (list, tuple))
    has_dim_equals = _fix_axis_dim_pairs(has_dim_equals, 'has_dim_equals')
    for axis, value in has_dim_equals:
      dim = _get_dim(tensor, axis)
      if dim is None or dim != value:
        _raise_value_error_for_dim('exactly', axis, value)

  if has_dim_greater_than is not None:
    _check_type(has_dim_greater_than, 'has_dim_greater_than', (list, tuple))
    has_dim_greater_than = _fix_axis_dim_pairs(has_dim_greater_than,
                                               'has_dim_greater_than')
    for axis, value in has_dim_greater_than:
      dim = _get_dim(tensor, axis)
      # A statically unknown dimension (None) fails the check instead of
      # raising TypeError on `None > value`.
      if dim is None or not dim > value:
        _raise_value_error_for_dim('greater than', axis, value)

  if has_dim_less_than is not None:
    _check_type(has_dim_less_than, 'has_dim_less_than', (list, tuple))
    has_dim_less_than = _fix_axis_dim_pairs(has_dim_less_than,
                                            'has_dim_less_than')
    for axis, value in has_dim_less_than:
      dim = _get_dim(tensor, axis)
      if dim is None or not dim < value:
        _raise_value_error_for_dim('less than', axis, value)