in tensorflow_probability/python/bijectors/rational_quadratic_spline.py [0:0]
def _parameter_control_dependencies(self, is_init):
  """Builds (and statically checks) validation for the spline parameters.

  Shape incompatibilities that are visible statically are reported right away
  by raising `ValueError`. When `self.validate_args` is set, runtime assertion
  ops are also built. A check that inspects a given parameter is emitted only
  when `is_init` differs from `tensor_util.is_ref(parameter)`, so for each
  parameter the check is produced for only one of the two values of `is_init`.

  Args:
    is_init: Python `bool`. Presumably `True` when invoked during construction
      and `False` at call time -- TODO confirm against the base-class contract.

  Returns:
    A list of assertion ops. It is always empty when `self.validate_args` is
    `False`.

  Raises:
    ValueError: if `bin_widths` and `bin_heights` do not broadcast; if their
      batch axes do not broadcast with those of `knot_slopes`; or if the
      statically known innermost axis of non-scalar `knot_slopes` is neither 1
      nor (innermost size of the broadcast bin shape) - 1.
  """
  # Parameters are converted to tensors lazily, at most once each, and the
  # converted tensors are shared by every check below.
  converters = {
      'bin_widths': tf.convert_to_tensor,
      'bin_heights': tf.convert_to_tensor,
      'knot_slopes': _ensure_at_least_1d,
  }
  converted = {}

  def _tensor(name):
    # Returns parameter `name` as a tensor, converting it on first request.
    if name not in converted:
      converted[name] = converters[name](getattr(self, name))
    return converted[name]

  # Static checks: the bin shapes must broadcast with each other, and their
  # batch axes (all but the innermost) must broadcast with `knot_slopes`'.
  try:
    bin_sizes_shape = tf.broadcast_static_shape(self.bin_widths.shape,
                                                self.bin_heights.shape)
  except ValueError as e:
    raise ValueError('`bin_widths`, `bin_heights` must broadcast: {}'.format(
        str(e)))
  try:
    tf.broadcast_static_shape(bin_sizes_shape[:-1],
                              self.knot_slopes.shape[:-1])
  except ValueError as e:
    raise ValueError(
        '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
        'batch axes: {}'.format(str(e)))

  slopes_shape = self.knot_slopes.shape
  assertions = []
  innermost_is_static = (
      tensorshape_util.is_fully_defined(bin_sizes_shape[-1:]) and
      tensorshape_util.is_fully_defined(slopes_shape[-1:]))
  if innermost_is_static:
    # Only non-scalar `knot_slopes` has an innermost axis to compare.
    if tensorshape_util.rank(slopes_shape) > 0:
      num_interior_knots = tensorshape_util.dims(bin_sizes_shape)[-1] - 1
      slopes_innermost = tensorshape_util.dims(slopes_shape)[-1]
      if slopes_innermost not in (1, num_interior_knots):
        raise ValueError(
            'Innermost axis of non-scalar `knot_slopes` must broadcast with '
            '{}; got {}.'.format(num_interior_knots, slopes_shape))
  elif self.validate_args:
    # The innermost sizes are not all known statically, so defer the
    # compatibility check to a runtime assertion.
    params = (self.bin_widths, self.bin_heights, self.knot_slopes)
    if is_init != any(tensor_util.is_ref(p) for p in params):
      widths = _tensor('bin_widths')
      heights = _tensor('bin_heights')
      slopes = _tensor('knot_slopes')
      interior = (widths + heights)[..., :-1]
      broadcast_shape = tf.broadcast_dynamic_shape(
          tf.shape(interior), tf.shape(slopes))
      assertions.append(
          assert_util.assert_greater(
              tf.shape(broadcast_shape)[0],
              tf.zeros([], dtype=broadcast_shape.dtype),
              message='`(bin_widths + bin_heights)[..., :-1]` must broadcast '
              'with `knot_slopes` to at least 1-D.'))

  if not self.validate_args:
    # Nothing above appends an assertion unless `validate_args` is set.
    assert not assertions
    return assertions

  widths_pending = is_init != tensor_util.is_ref(self.bin_widths)
  heights_pending = is_init != tensor_util.is_ref(self.bin_heights)

  # Widths and heights are required to have matching totals along the last
  # axis; re-check whenever either of them is due for checking now.
  if widths_pending or heights_pending:
    widths = _tensor('bin_widths')
    heights = _tensor('bin_heights')
    assertions.append(
        assert_util.assert_near(
            tf.reduce_sum(widths, axis=-1),
            tf.reduce_sum(heights, axis=-1),
            message='`sum(bin_widths, axis=-1)` must equal '
            '`sum(bin_heights, axis=-1)`.'))

  positivity_checks = (
      ('bin_widths', widths_pending, '`bin_widths` must be positive.'),
      ('bin_heights', heights_pending, '`bin_heights` must be positive.'),
      ('knot_slopes', is_init != tensor_util.is_ref(self.knot_slopes),
       '`knot_slopes` must be positive.'),
  )
  for name, pending, message in positivity_checks:
    if pending:
      assertions.append(
          assert_util.assert_positive(_tensor(name), message=message))
  return assertions