tensorflow_lattice/python/kronecker_factored_lattice_lib.py
def _assert_bound_constraints(weights, units, scale, output_min, output_max,
eps):
"""Asserts that weights satisfy monotonicity constraints.
Args:
weights: `KroneckerFactoredLattice` weights tensor of shape: `(1,
lattice_sizes, units * dims, num_terms)`.
units: Number of units per input dimension.
scale: Scale variable of shape: `(units, num_terms)`.
output_min: None or minimum layer output.
output_max: None or maximum layer output.
eps: Allowed constraints violation.

  Returns:
    List of bound assertion ops in graph mode, or directly executes the
    assertions in eager mode and returns a list of NoneType elements.
"""
bound_asserts = []
# Recall that w.shape is (1, lattice_sizes, units * dims, num_terms).
weights_shape = weights.get_shape().as_list()
_, lattice_sizes, units_times_dims, num_terms = weights_shape
assert units_times_dims % units == 0
dims = units_times_dims // units
weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])
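  # After the reshape, weights has shape (-1, lattice_sizes, units, dims,
  # num_terms), which lets us reduce per unit and per dimension below.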
  # If both bounds are specified, the maximum possible absolute output of
  # each term must be at most 1 (within eps).
if output_min is not None and output_max is not None:
for term, term_weights in enumerate(tf.unstack(weights, axis=4)):
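      # Each term's output is a product across `dims`, so an upper bound on
      # the term's |output| is the product over dims of the per-dimension
      # maximum |keypoint value|.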
max_keypoint_values = tf.reduce_max(
tf.abs(term_weights), axis=1, keepdims=True)
max_output_values = tf.reduce_prod(
max_keypoint_values, axis=3, keepdims=True)
for unit, unit_max_output_value in enumerate(
tf.unstack(max_output_values, axis=2)):
diff = tf.squeeze(1 - unit_max_output_value)
bound_asserts.append(
tf.Assert(
diff >= -eps,
data=[
"Bound violation (max output greater than 1)", "Diff", diff,
"Epsilon", eps, "Maximum output value",
unit_max_output_value, "Term index", term, "Unit", unit,
"Weights", weights
]))
  else:
    # If only one bound is specified, all weights must be nonnegative at this
    # point. No epsilon tolerance is allowed here, since even a slightly
    # negative weight can flip the sign of a term's product.
total_negative_weights = tf.reduce_sum(tf.cast(weights < 0, tf.int32))
bound_asserts.append(
tf.Assert(
total_negative_weights <= 0,
data=[
"Bound violation (negative weights)",
"Number of negative weights", total_negative_weights, "Weights",
weights
]))
# If both bounds are specified, scale must be between
# -(output_max-output_min)/2 and (output_max-output_min)/2. If only output_min
# is specified, then scale must be nonnegative. If only output_max is
# specified, then scale must be nonpositive.
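  # Intuition: in the one-bound case the weights (and hence each term's
  # product) are nonnegative, so the sign of scale controls which side of the
  # bound the output can move toward.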
if output_min is not None and output_max is not None:
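    # The term outputs were checked above to lie in [-1, 1], so a scale of
    # magnitude at most half the output range keeps the final output within
    # [output_min, output_max].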
bound = (output_max - output_min) / 2.0
below_bound_scales = tf.reduce_sum(tf.cast(scale < -bound, tf.int32))
    above_bound_scales = tf.reduce_sum(tf.cast(scale > bound, tf.int32))
bound_asserts.append(
tf.Assert(
            below_bound_scales + above_bound_scales <= 0,
data=[
"Bound violation (scale out of bounds)", "Bound", bound,
"Scale", scale
]))
elif output_min is not None:
negative_scales = tf.reduce_sum(tf.cast(scale < 0, tf.int32))
bound_asserts.append(
tf.Assert(
negative_scales <= 0,
data=[
"Bound violation (only output_min specified with negative "
"scale values)", "Scale", scale
]))
elif output_max is not None:
positive_scales = tf.reduce_sum(tf.cast(scale > 0, tf.int32))
bound_asserts.append(
tf.Assert(
positive_scales <= 0,
data=[
"Bound violation (only output_max specified with positive "
"scale values)", "Scale", scale
]))
return bound_asserts
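

# Illustrative usage sketch (not part of the library API; the function name,
# sizes, and values here are assumptions for demonstration). It builds a toy
# weights tensor that satisfies both bounds and runs the assertions eagerly,
# assuming the module-level `import tensorflow as tf`.
def _example_assert_bound_constraints():
  lattice_sizes, units, dims, num_terms = 2, 1, 3, 2
  # All |weights| <= 1, so each term's maximum |output| (a product across
  # `dims`) is 0.5**3 = 0.125 <= 1, satisfying the two-sided output check.
  weights = tf.fill([1, lattice_sizes, units * dims, num_terms], 0.5)
  # |scale| must be at most (output_max - output_min) / 2 = 0.5.
  scale = tf.fill([units, num_terms], 0.25)
  # In eager mode the assertions execute immediately and the returned list
  # contains None entries; in graph mode it holds the assertion ops.
  return _assert_bound_constraints(
      weights, units, scale, output_min=0.0, output_max=1.0, eps=1e-6)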