def assert_constraints()

in tensorflow_lattice/python/lattice_lib.py [0:0]


def assert_constraints(weights,
                       lattice_sizes,
                       monotonicities,
                       edgeworth_trusts,
                       trapezoid_trusts,
                       monotonic_dominances,
                       range_dominances,
                       joint_monotonicities,
                       joint_unimodalities,
                       output_min=None,
                       output_max=None,
                       eps=1e-6):
  """Asserts that weights satisfy constraints.

  Each constraint is checked by building a `tf.Assert` op whose condition is
  that the relevant minimum difference is not below `-eps`; the failure
  message includes the offending values.

  Args:
    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.
    lattice_sizes: List or tuple of integers which represents lattice sizes.
    monotonicities: Monotonicity constraints, one entry per lattice dimension.
      Only entries equal to 1 are checked. May be None or empty.
    edgeworth_trusts: Edgeworth trust constraints given as an iterable of
      `(main_dim, cond_dim, cond_direction)` tuples, or None.
    trapezoid_trusts: Trapezoid trust constraints given as an iterable of
      `(main_dim, cond_dim, cond_direction)` tuples, or None.
    monotonic_dominances: Monotonic dominance constraints given as an iterable
      of `(dominant_dim, weak_dim)` pairs, or None.
    range_dominances: Range dominance constraints given as an iterable of
      `(dominant_dim, weak_dim)` pairs, or None.
    joint_monotonicities: Joint monotonicity constraints given as an iterable
      of `(dim1, dim2)` pairs, or None.
    joint_unimodalities: Joint unimodality constraints. Currently ignored (not
      yet asserted).
    output_min: None or lower bound constraints.
    output_max: None or upper bound constraints.
    eps: Allowed constraints violation.

  Returns:
    List of assertion ops in graph mode or directly executes assertions in eager
    mode.
  """
  # TODO: actually assert them.
  del joint_unimodalities

  # Copy into a list so that tuples are accepted (list concatenation/append
  # below would otherwise raise TypeError) and the caller's sequence is never
  # mutated.
  lattice_sizes = list(lattice_sizes)
  if weights.shape[1] > 1:
    # Multiple units are reshaped into an extra trailing dimension. That
    # dimension gets monotonicity 0 so the monotonicity loop skips it.
    lattice_sizes.append(int(weights.shape[1]))
    if monotonicities:
      monotonicities = list(monotonicities) + [0]
  weights = tf.reshape(weights, shape=lattice_sizes)
  asserts = []

  for i, monotonicity in enumerate(monotonicities or []):
    if monotonicity != 1:
      continue
    weights_layers = tf.unstack(weights, axis=i)

    for j in range(1, len(weights_layers)):
      diff = tf.reduce_min(weights_layers[j] - weights_layers[j - 1])
      asserts.append(
          tf.Assert(
              diff >= -eps,
              data=[
                  "Monotonicity violation", "Feature index:", i,
                  "Min monotonicity diff:", diff, "Upper layer number:", j,
                  "Epsilon:", eps, "Layers:", weights_layers[j],
                  weights_layers[j - 1]
              ]))

  for main_dim, cond_dim, cond_direction in edgeworth_trusts or []:
    weights_layers = _unstack_nd(weights, [main_dim, cond_dim])
    for i in range(lattice_sizes[main_dim] - 1):
      for j in range(lattice_sizes[cond_dim] - 1):
        diff = tf.reduce_min(
            cond_direction *
            ((weights_layers[i + 1][j + 1] - weights_layers[i][j + 1]) -
             (weights_layers[i + 1][j] - weights_layers[i][j])))
        asserts.append(
            tf.Assert(
                diff >= -eps,
                data=[
                    "Edgeworth trust violation", "Feature indices:", main_dim,
                    ",", cond_dim, "Min trust diff:", diff, "Epsilon:", eps,
                    "Layers:", weights_layers[i + 1][j + 1],
                    weights_layers[i][j + 1], weights_layers[i + 1][j],
                    weights_layers[i][j]
                ]))

  for main_dim, cond_dim, cond_direction in trapezoid_trusts or []:
    weights_layers = _unstack_nd(weights, [main_dim, cond_dim])
    max_main_dim = lattice_sizes[main_dim] - 1
    for j in range(lattice_sizes[cond_dim] - 1):
      # Lower edge of the main dimension.
      lhs_diff = tf.reduce_min(
          cond_direction * (weights_layers[0][j] - weights_layers[0][j + 1]))
      asserts.append(
          tf.Assert(
              lhs_diff >= -eps,
              data=[
                  "Trapezoid trust violation", "Feature indices:", main_dim,
                  ",", cond_dim, "Min trust diff:", lhs_diff, "Epsilon:", eps,
                  "Layers:", weights_layers[0][j], weights_layers[0][j + 1]
              ]))
      # Upper edge of the main dimension.
      rhs_diff = tf.reduce_min(cond_direction *
                               (weights_layers[max_main_dim][j + 1] -
                                weights_layers[max_main_dim][j]))
      asserts.append(
          tf.Assert(
              rhs_diff >= -eps,
              data=[
                  "Trapezoid trust violation", "Feature indices:", main_dim,
                  ",", cond_dim, "Min trust diff:", rhs_diff, "Epsilon:", eps,
                  "Layers:", weights_layers[max_main_dim][j + 1],
                  weights_layers[max_main_dim][j]
              ]))

  for dominant_dim, weak_dim in monotonic_dominances or []:
    weights_layers = _unstack_nd(weights, [dominant_dim, weak_dim])
    for i in range(lattice_sizes[dominant_dim] - 1):
      for j in range(lattice_sizes[weak_dim] - 1):
        # Midpoint of the cell diagonal, compared against the two off-diagonal
        # corners.
        midpoint = (weights_layers[i + 1][j + 1] + weights_layers[i][j]) / 2
        dominant_diff = tf.reduce_min(weights_layers[i + 1][j] - midpoint)
        asserts.append(
            tf.Assert(
                dominant_diff >= -eps,
                data=[
                    "Dominance violation", "Feature indices:", dominant_dim,
                    ",", weak_dim, "Min dominance diff:", dominant_diff,
                    "Epsilon:", eps, "Layers:", weights_layers[i][j],
                    weights_layers[i + 1][j], weights_layers[i + 1][j + 1]
                ]))
        weak_diff = tf.reduce_min(midpoint - weights_layers[i][j + 1])
        asserts.append(
            tf.Assert(
                weak_diff >= -eps,
                data=[
                    "Dominance violation", "Feature indices:", dominant_dim,
                    ",", weak_dim, "Min dominance diff:", weak_diff, "Epsilon:",
                    eps, "Layers:", weights_layers[i][j],
                    weights_layers[i + 1][j], weights_layers[i + 1][j + 1]
                ]))

  for dominant_dim, weak_dim in range_dominances or []:
    weights_layers = _unstack_nd(weights, [dominant_dim, weak_dim])
    dom_dim_size = lattice_sizes[dominant_dim]
    weak_dim_size = lattice_sizes[weak_dim]
    for i in range(dom_dim_size):
      for j in range(weak_dim_size):
        # Full-range change along the dominant dim must not be smaller than
        # the full-range change along the weak dim.
        diff = tf.reduce_min(
            (weights_layers[dom_dim_size - 1][j] - weights_layers[0][j]) -
            (weights_layers[i][weak_dim_size - 1] - weights_layers[i][0]))
        asserts.append(
            tf.Assert(
                diff >= -eps,
                data=[
                    "Range dominance violation", "Feature indices:",
                    dominant_dim, ",", weak_dim, "Min dominance diff:", diff,
                    "Epsilon:", eps, "Layers:",
                    weights_layers[dom_dim_size - 1][j], weights_layers[0][j],
                    weights_layers[i][weak_dim_size - 1], weights_layers[i][0]
                ]))

  for dim1, dim2 in joint_monotonicities or []:
    weights_layers = _unstack_nd(weights, [dim1, dim2])
    for i in range(lattice_sizes[dim1] - 1):
      for j in range(lattice_sizes[dim2] - 1):
        # Midpoint of the anti-diagonal, compared against the two diagonal
        # corners of the cell.
        midpoint = (weights_layers[i + 1][j] + weights_layers[i][j + 1]) / 2
        lower_triangle_diff = tf.reduce_min(weights_layers[i + 1][j + 1] -
                                            midpoint)
        asserts.append(
            tf.Assert(
                lower_triangle_diff >= -eps,
                data=[
                    "Joint monotonicity violation", "Feature indices:", dim1,
                    ",", dim2, "Min lower triangle diff:", lower_triangle_diff,
                    "Epsilon:", eps, "Layers:", weights_layers[i + 1][j + 1],
                    weights_layers[i + 1][j], weights_layers[i][j + 1]
                ]))
        upper_triangle_diff = tf.reduce_min(midpoint - weights_layers[i][j])
        asserts.append(
            tf.Assert(
                upper_triangle_diff >= -eps,
                data=[
                    "Joint monotonicity violation", "Feature indices:", dim1,
                    ",", dim2, "Min upper triangle diff:", upper_triangle_diff,
                    "Epsilon:", eps, "Layers:", weights_layers[i][j],
                    weights_layers[i + 1][j], weights_layers[i][j + 1]
                ]))

  if output_min is not None:
    min_weight = tf.reduce_min(weights)
    asserts.append(
        tf.Assert(
            min_weight >= output_min - eps,
            data=[
                "Lower bound violation.", "output_min:", output_min,
                "Smallest weight:", min_weight, "Epsilon:", eps, "Weights:",
                weights
            ]))

  if output_max is not None:
    max_weight = tf.reduce_max(weights)
    asserts.append(
        tf.Assert(
            max_weight <= output_max + eps,
            data=[
                "Upper bound violation.", "output_max:", output_max,
                "Largest weight:", max_weight, "Epsilon:", eps, "Weights:",
                weights
            ]))
  return asserts