in tensorflow_graphics/util/shape.py [0:0]
def _broadcast_shape_helper(shape_x: tf.TensorShape,
                            shape_y: tf.TensorShape) -> Optional[List[Any]]:
  """Computes the broadcast of two shapes, dimension by dimension.

  Shared implementation detail of is_broadcast_compatible and broadcast_shape.
  The two shapes are aligned on their trailing dimensions, the shorter one is
  padded on the left with dimensions of size 1, and each aligned pair is then
  combined following the numpy broadcasting rules.

  Note: `dims` of both shapes is iterated directly, so both shapes are
  presumably expected to have a known rank (to be confirmed against callers).

  Args:
    shape_x: A `TensorShape`.
    shape_y: A `TensorShape`.

  Returns:
    None when the shapes cannot be broadcast against each other. Otherwise a
    list, ordered from the leading to the trailing dimension, whose entries
    are either dimension objects or None when the broadcast size of that
    dimension cannot be determined.
  """
  # Align both shapes on their last dimension; the shorter shape is completed
  # with size-1 dimensions. The pairs are produced trailing-first, hence the
  # in-place reversal to recover the leading-to-trailing order.
  aligned_pairs = list(
      six.moves.zip_longest(
          reversed(shape_x.dims),
          reversed(shape_y.dims),
          fillvalue=tf.compat.v1.Dimension(1)))
  aligned_pairs.reverse()

  # Merge every pair according to the numpy broadcasting rules, see
  # https://numpy.org/doc/stable/user/basics.broadcasting.html
  merged = []
  for dim_x, dim_y in aligned_pairs:
    size_x = dim_x.value
    size_y = dim_y.value

    if size_x is None or size_y is None:
      # At least one side is unknown. When the other side is known and larger
      # than 1, the program is assumed to be correct and the unknown side is
      # expected to broadcast to it. Otherwise the result stays unknown.
      candidate = dim_y if size_x is None else dim_x
      candidate_size = candidate.value
      if candidate_size is not None and candidate_size > 1:
        merged.append(candidate)
      else:
        merged.append(None)
      continue

    if size_x == 1:
      # dim_x gets stretched to match dim_y.
      merged.append(dim_y)
    elif size_y == 1:
      # dim_y gets stretched to match dim_x.
      merged.append(dim_x)
    elif size_x == size_y:
      # Identical sizes: the output keeps that size.
      merged.append(dim_x.merge_with(dim_y))
    else:
      # Two different sizes, neither of them 1: not broadcastable.
      return None
  return merged