def project_all_constraints()

in tensorflow_lattice/python/pwl_calibration_lib.py [0:0]


def project_all_constraints(weights,
                            monotonicity,
                            output_min,
                            output_max,
                            output_min_constraints,
                            output_max_constraints,
                            convexity,
                            lengths,
                            num_projection_iterations=8):
  """Projects PWL calibration weights onto all requested constraints at once.

  The projection is the nearest point in the L2 sense for every combination of
  constraints, with one exception: bounds requested without monotonicity. In
  that case a heuristic maps the input to some feasible point, with no guarantee
  about its distance from the true projection.

  When the constraints amount to a single projection step (e.g. only bounds or
  only monotonicity), that step is applied once and its result is returned
  directly. Otherwise `num_projection_iterations` rounds of Dykstra's
  alternating projection algorithm are run over all constraint sets. Dykstra's
  algorithm converges to the exact L2 projection but approaches it from the
  infeasible side, so the result is passed through a final approximate
  projection which lands strictly inside the feasible set. That last step is
  not an exact L2 projection, but with enough Dykstra iterations its effect
  should be negligible.

  Note: with bound and convexity constraints but no monotonicity, the
  constraints are not fully satisfied. Using more iterations reduces the
  violation.

  Args:
    weights: `(num_keypoints, units)`-shape tensor holding the weights of the
      PWL calibration layer. Row 0 is treated as the bias and the remaining
      rows as heights.
    monotonicity: 1 for increasing, -1 for decreasing, 0 for no monotonicity
      constraints.
    output_min: Lower bound constraint of the PWL calibration layer.
    output_max: Upper bound constraint of the PWL calibration layer.
    output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`
      describing the constraints on the layer's minimum value.
    output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`
      describing the constraints on the layer's maximum value.
    convexity: 1 for convex, -1 for concave, 0 for no convexity constraints.
    lengths: Lengths of the pieces of the piecewise linear function. Only
      needed when a convexity projection is requested.
    num_projection_iterations: Number of iterations of Dykstra's alternating
      projection algorithm.

  Returns:
    Projected weights tensor.
  """
  initial_bias = weights[0:1]
  initial_heights = weights[1:]

  # Convexity is enforced as two separate constraint groups. Each entry is
  # (key of the stored Dykstra correction, constraint group id, minimum number
  # of heights required for the group to be projected).
  convexity_groups = (
      ("CONVEXITY_0", 0, 2),
      ("CONVEXITY_1", 1, 3),
  )

  def body(projection_counter, bias, heights, last_bias_change,
           last_heights_change):
    """Runs one Dykstra sweep over every enabled constraint set.

    Used both as a plain Python call (to probe which constraint sets are
    active) and as the body of `tf.while_loop`.

    Args:
      projection_counter: Counter tensor or number at the start of the sweep.
      bias: Bias tensor at the start of the sweep.
      heights: Heights tensor at the start of the sweep.
      last_bias_change: Dict holding, per constraint set, the most recent change
        that projecting onto that set made to the bias.
      last_heights_change: Dict holding, per constraint set, the most recent
        change that projecting onto that set made to the heights.

    Returns:
      The tuple `(projection_counter, bias, heights, last_bias_change,
      last_heights_change)` as of the end of the sweep, where the counter has
      been advanced by the number of projections applied.
    """
    # Work on shallow copies so the dicts handed in by the caller stay intact.
    # copy.copy also keeps the default factory of a defaultdict argument.
    bias_deltas = copy.copy(last_bias_change)
    heights_deltas = copy.copy(last_heights_change)
    applied = 0

    # ---------------------------- Bounds ----------------------------
    unbounded = BoundConstraintsType.NONE
    if (output_min_constraints != unbounded or
        output_max_constraints != unbounded):
      # Undo the previous correction for this set before re-projecting.
      bias_before = bias - bias_deltas["BOUNDS"]
      heights_before = heights - heights_deltas["BOUNDS"]
      bounds_kwargs = dict(
          bias=bias_before,
          heights=heights_before,
          output_min=output_min,
          output_max=output_max,
          output_min_constraints=output_min_constraints,
          output_max_constraints=output_max_constraints)
      if monotonicity != 0:
        bias, heights = _project_bounds_considering_monotonicity(
            monotonicity=monotonicity, **bounds_kwargs)
      else:
        bias, heights = _approximately_project_bounds_only(**bounds_kwargs)
      bias_deltas["BOUNDS"] = bias - bias_before
      heights_deltas["BOUNDS"] = heights - heights_before
      applied += 1

    # ------------------------- Monotonicity -------------------------
    if monotonicity != 0:
      heights_before = heights - heights_deltas["MONOTONICITY"]
      heights = _project_monotonicity(
          heights=heights_before, monotonicity=monotonicity)
      heights_deltas["MONOTONICITY"] = heights - heights_before
      applied += 1

    # -------------------------- Convexity ---------------------------
    if convexity != 0:
      for delta_key, group, min_num_heights in convexity_groups:
        if heights.shape[0] >= min_num_heights:
          heights_before = heights - heights_deltas[delta_key]
          heights = _project_convexity(
              heights=heights_before,
              lengths=lengths,
              convexity=convexity,
              constraint_group=group)
          heights_deltas[delta_key] = heights - heights_before
          applied += 1

    return (projection_counter + applied, bias, heights, bias_deltas,
            heights_deltas)

  # Probe run: execute one sweep eagerly in Python. This tells us how many
  # projections a sweep consists of (so a lone projection can skip the loop),
  # and, through the defaultdicts, which correction keys the sweep touches.
  # tf.while_loop needs the same set of dict keys going into and coming out of
  # its body on every iteration.
  zero_bias = tf.zeros_like(initial_bias)
  zero_heights = tf.zeros_like(initial_heights)
  probe_result = body(0, initial_bias, initial_heights,
                      collections.defaultdict(lambda: zero_bias),
                      collections.defaultdict(lambda: zero_heights))
  (projections_per_sweep, projected_bias, projected_heights,
   bias_change_keys, heights_change_keys) = probe_result

  if projections_per_sweep <= 1:
    # At most one constraint set: the single projection is the final answer.
    return tf.concat([projected_bias, projected_heights], axis=0)

  def cond(projection_counter, bias, heights, last_bias_change,
           last_heights_change):
    """Keeps looping until the projection budget is exhausted."""
    del bias, heights, last_bias_change, last_heights_change  # Unused.
    projection_budget = num_projection_iterations * projections_per_sweep
    return tf.less(projection_counter, projection_budget)

  # Run Dykstra's algorithm from the original weights with all corrections
  # reset to zero, using exactly the keys discovered by the probe run.
  loop_state = (
      tf.constant(0),
      initial_bias,
      initial_heights,
      {key: zero_bias for key in bias_change_keys},
      {key: zero_heights for key in heights_change_keys},
  )
  _, dykstra_bias, dykstra_heights, _, _ = tf.while_loop(
      cond, body, loop_state)

  # Dykstra's algorithm is iterative, so finish with the approximate projection
  # that guarantees the constraints are strictly met.
  return _finalize_constraints(
      bias=dykstra_bias,
      heights=dykstra_heights,
      monotonicity=monotonicity,
      output_min=output_min,
      output_max=output_max,
      output_min_constraints=output_min_constraints,
      output_max_constraints=output_max_constraints,
      convexity=convexity,
      lengths=lengths)