def _parameter_control_dependencies()

in tensorflow_probability/python/bijectors/rational_quadratic_spline.py [0:0]


  def _parameter_control_dependencies(self, is_init):
    """Checks parameter shapes and builds the parameter value assertions.

    Static shape compatibility is verified on every call and violations are
    reported by raising. Value assertions are only built when
    `self.validate_args` is truthy, and the assertions for a given parameter
    are only built when `is_init` differs from `tensor_util.is_ref(...)` of
    that parameter.

    Args:
      is_init: Flag compared against `tensor_util.is_ref(...)` of each
        parameter to select which assertions are built on this call.

    Returns:
      A Python `list` of assertions; empty when `self.validate_args` is falsy.

    Raises:
      ValueError: If the static shapes of `bin_widths` and `bin_heights` do not
        broadcast, if their batch axes do not broadcast with the batch axes of
        `knot_slopes`, or if the statically known innermost axis of a
        non-scalar `knot_slopes` is neither 1 nor one less than the innermost
        axis of the broadcast bin shape.
    """
    # Static check 1: widths and heights must be mutually broadcastable.
    try:
      bin_sizes_shape = tf.broadcast_static_shape(
          self.bin_widths.shape, self.bin_heights.shape)
    except ValueError as e:
      raise ValueError('`bin_widths`, `bin_heights` must broadcast: {}'.format(
          str(e)))

    # Static check 2: batch axes (everything but the innermost axis) must
    # broadcast with those of the slopes. Only the raise-on-mismatch effect is
    # needed here; the broadcast result itself is not used afterwards.
    try:
      tf.broadcast_static_shape(
          bin_sizes_shape[:-1], self.knot_slopes.shape[:-1])
    except ValueError as e:
      raise ValueError(
          '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
          'batch axes: {}'.format(str(e)))

    # Lazily materialized tensors, shared by the checks below so that each
    # parameter is converted at most once per call.
    bw = None
    bh = None
    kd = None
    assertions = []

    slopes_shape = self.knot_slopes.shape
    innermost_axes_known = (
        tensorshape_util.is_fully_defined(bin_sizes_shape[-1:]) and
        tensorshape_util.is_fully_defined(slopes_shape[-1:]))

    if innermost_axes_known:
      # Innermost sizes are static: validate them right away in Python.
      if tensorshape_util.rank(slopes_shape) > 0:
        num_interior_knots = tensorshape_util.dims(bin_sizes_shape)[-1] - 1
        slopes_innermost = tensorshape_util.dims(slopes_shape)[-1]
        if slopes_innermost not in (1, num_interior_knots):
          raise ValueError(
              'Innermost axis of non-scalar `knot_slopes` must broadcast with '
              '{}; got {}.'.format(num_interior_knots, slopes_shape))
    elif self.validate_args and is_init != any(
        tensor_util.is_ref(param) for param in
        (self.bin_widths, self.bin_heights, self.knot_slopes)):
      # Innermost sizes are not static: defer to an assertion on the
      # dynamically broadcast shape.
      bw = tf.convert_to_tensor(self.bin_widths)
      bh = tf.convert_to_tensor(self.bin_heights)
      kd = _ensure_at_least_1d(self.knot_slopes)
      sizes_shape = tf.shape((bw + bh)[..., :-1])
      joint_shape = tf.broadcast_dynamic_shape(sizes_shape, tf.shape(kd))
      joint_rank = tf.shape(joint_shape)[0]
      assertions.append(
          assert_util.assert_greater(
              joint_rank,
              tf.zeros([], dtype=joint_shape.dtype),
              message='`(bin_widths + bin_heights)[..., :-1]` must broadcast '
              'with `knot_slopes` to at least 1-D.'))

    if not self.validate_args:
      # Nothing above may have added an assertion without `validate_args`.
      assert not assertions
      return assertions

    # Total width must match total height along the innermost axis.
    if (is_init != tensor_util.is_ref(self.bin_widths) or
        is_init != tensor_util.is_ref(self.bin_heights)):
      if bw is None:
        bw = tf.convert_to_tensor(self.bin_widths)
      if bh is None:
        bh = tf.convert_to_tensor(self.bin_heights)
      total_width = tf.reduce_sum(bw, axis=-1)
      total_height = tf.reduce_sum(bh, axis=-1)
      assertions.append(
          assert_util.assert_near(
              total_width,
              total_height,
              message='`sum(bin_widths, axis=-1)` must equal '
              '`sum(bin_heights, axis=-1)`.'))

    # Per-parameter positivity checks, in widths / heights / slopes order.
    if is_init != tensor_util.is_ref(self.bin_widths):
      if bw is None:
        bw = tf.convert_to_tensor(self.bin_widths)
      assertions.append(
          assert_util.assert_positive(
              bw, message='`bin_widths` must be positive.'))

    if is_init != tensor_util.is_ref(self.bin_heights):
      if bh is None:
        bh = tf.convert_to_tensor(self.bin_heights)
      assertions.append(
          assert_util.assert_positive(
              bh, message='`bin_heights` must be positive.'))

    if is_init != tensor_util.is_ref(self.knot_slopes):
      if kd is None:
        kd = _ensure_at_least_1d(self.knot_slopes)
      assertions.append(
          assert_util.assert_positive(
              kd, message='`knot_slopes` must be positive.'))

    return assertions