def _assert_bound_constraints()

in tensorflow_lattice/python/kronecker_factored_lattice_lib.py [0:0]


def _assert_bound_constraints(weights, units, scale, output_min, output_max,
                              eps):
  """Asserts that weights and scale satisfy the output bound constraints.

  The checks performed depend on which bounds are given:

  * Both `output_min` and `output_max`: for every term and unit, the product
    over input dimensions of the largest absolute keypoint weight must not
    exceed 1 (up to `eps`). In addition, every scale value must lie in
    `[-(output_max - output_min) / 2, (output_max - output_min) / 2]`.
  * Only one bound: all weights must be nonnegative (no `eps` tolerance).
    Scale must be nonnegative when only `output_min` is given, and nonpositive
    when only `output_max` is given.
  * Neither bound: there is nothing to check and an empty list is returned.

  Args:
    weights: `KroneckerFactoredLattice` weights tensor of shape: `(1,
      lattice_sizes, units * dims, num_terms)`.
    units: Number of units. `weights.shape[2]` must be divisible by it.
    scale: Scale variable of shape: `(units, num_terms)`.
    output_min: None or minimum layer output.
    output_max: None or maximum layer output.
    eps: Allowed violation of the maximum-output check (used only when both
      bounds are specified).

  Returns:
    List of bound assertion ops in graph mode or directly executes
    assertions in eager mode and returns a list of NoneType elements.
  """
  bound_asserts = []

  # Without any bound there are no bound constraints to verify. In particular,
  # the weights must not be required to be nonnegative in that case.
  if output_min is None and output_max is None:
    return bound_asserts

  # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms).
  weights_shape = weights.get_shape().as_list()
  _, lattice_sizes, units_times_dims, num_terms = weights_shape
  assert units_times_dims % units == 0, (
      "weights.shape[2] ({}) must be divisible by units ({})".format(
          units_times_dims, units))
  dims = units_times_dims // units
  # Split the combined `units * dims` axis so that units and dims are separate
  # axes: (1, lattice_sizes, units, dims, num_terms).
  weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])

  # If both bounds are specified, we must also have that the maximum output be
  # between -1 and 1.
  if output_min is not None and output_max is not None:
    for term, term_weights in enumerate(tf.unstack(weights, axis=4)):
      # term_weights: (1, lattice_sizes, units, dims). Take the largest
      # absolute keypoint weight per (unit, dim): (1, 1, units, dims).
      max_keypoint_values = tf.reduce_max(
          tf.abs(term_weights), axis=1, keepdims=True)
      # Multiply across dims to get each unit's maximum output for this term:
      # (1, 1, units, 1).
      max_output_values = tf.reduce_prod(
          max_keypoint_values, axis=3, keepdims=True)
      for unit, unit_max_output_value in enumerate(
          tf.unstack(max_output_values, axis=2)):
        diff = tf.squeeze(1 - unit_max_output_value)
        bound_asserts.append(
            tf.Assert(
                diff >= -eps,
                data=[
                    "Bound violation (max output greater than 1)", "Diff", diff,
                    "Epsilon", eps, "Maximum output value",
                    unit_max_output_value, "Term index", term, "Unit", unit,
                    "Weights", weights
                ]))
  else:
    # Exactly one bound is specified, so all of our weights must be
    # nonnegative at this point. There can be no allowed epsilon error here
    # because of the effect of a negative value.
    total_negative_weights = tf.reduce_sum(tf.cast(weights < 0, tf.int32))
    bound_asserts.append(
        tf.Assert(
            total_negative_weights <= 0,
            data=[
                "Bound violation (negative weights)",
                "Number of negative weights", total_negative_weights, "Weights",
                weights
            ]))

  # If both bounds are specified, scale must be between
  # -(output_max-output_min)/2 and (output_max-output_min)/2. If only output_min
  # is specified, then scale must be nonnegative. If only output_max is
  # specified, then scale must be nonpositive.
  if output_min is not None and output_max is not None:
    bound = (output_max - output_min) / 2.0
    below_bound_scales = tf.reduce_sum(tf.cast(scale < -bound, tf.int32))
    above_bound_scale = tf.reduce_sum(tf.cast(scale > bound, tf.int32))
    bound_asserts.append(
        tf.Assert(
            below_bound_scales + above_bound_scale <= 0,
            data=[
                "Bound violation (scale out of bounds)", "Bound", bound,
                "Scale", scale
            ]))
  elif output_min is not None:
    negative_scales = tf.reduce_sum(tf.cast(scale < 0, tf.int32))
    bound_asserts.append(
        tf.Assert(
            negative_scales <= 0,
            data=[
                "Bound violation (only output_min specified with negative "
                "scale values)", "Scale", scale
            ]))
  elif output_max is not None:
    positive_scales = tf.reduce_sum(tf.cast(scale > 0, tf.int32))
    bound_asserts.append(
        tf.Assert(
            positive_scales <= 0,
            data=[
                "Bound violation (only output_max specified with positive "
                "scale values)", "Scale", scale
            ]))

  return bound_asserts