in tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_util.py [0:0]
def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
"""Detect if input should be interpreted as a list of blocks."""
# Tuples and lists of length equal to the number of operators may be
# blockwise.
if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
# If the elements of the iterable are not nested, interpret the input as
# blockwise.
if not any(nest.is_nested(x) for x in arg):
return True
else:
arg_dims = [ops.convert_to_tensor(
x).shape[arg_split_dim] for x in arg]
self_dims = [dim.value for dim in block_dimensions]
# If none of the operator dimensions are known, interpret the input as
# blockwise if its matching dimensions are unequal.
if all(self_d is None for self_d in self_dims):
# A nested tuple/list with a single outermost element is not blockwise
if len(arg_dims) == 1:
return False
elif any(dim != arg_dims[0] for dim in arg_dims):
return True
else:
raise ValueError(
"Parsing of the input structure is ambiguous. Please input "
"a blockwise iterable of `Tensor`s or a single `Tensor`.")
# If input dimensions equal the respective (known) blockwise operator
# dimensions, then the input is blockwise.
if all(self_d == arg_d or self_d is None
for self_d, arg_d in zip(self_dims, arg_dims)):
return True
# If input dimensions equals are all equal, and are greater than or equal
# to the sum of the known operator dimensions, interpret the input as
# blockwise.
# input is not blockwise.
self_dim = sum(self_d for self_d in self_dims if self_d is not None)
if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
return False
# If none of these conditions is met, the input shape is mismatched.
raise ValueError("Input dimension does not match operator dimension.")
else:
return False