# project_by_dykstra()
#
# Excerpt from tensorflow_lattice/python/lattice_lib.py


def project_by_dykstra(weights,
                       lattice_sizes,
                       monotonicities=None,
                       unimodalities=None,
                       edgeworth_trusts=None,
                       trapezoid_trusts=None,
                       monotonic_dominances=None,
                       range_dominances=None,
                       joint_monotonicities=None,
                       joint_unimodalities=None,
                       num_iterations=1):
  """Applies Dykstra's projection algorithm for monotonicity/trust constraints.

  - Returns honest projection with respect to L2 norm if num_iterations is inf.
  - Monotonicity will be violated by some small eps(num_iterations).
  - Complexity: O(num_iterations * (num_monotonic_dims + num_trust_constraints)
    * num_lattice_weights)

  Dykstra's alternating projections algorithm projects into intersection of
  several convex sets. For algorithm description itself use Google or Wiki:
  https://en.wikipedia.org/wiki/Dykstra%27s_projection_algorithm

  Here, each monotonicity constraint is split up into 2 independent convex
  sets and each trust constraint is split up into 4 independent convex sets.
  These sets are then projected onto exactly (in L2 space). For more details,
  see the _project_partial_* functions.

  Args:
    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.
    lattice_sizes: list or tuple of integers which represents lattice sizes.
      which correspond to weights.
    monotonicities: None or list or tuple of same length as lattice_sizes of {0,
      1} which represents monotonicity constraints per dimension. 1 stands for
      increasing (non-decreasing in fact), 0 for no monotonicity constraints.
    unimodalities: None or list or tuple of same length as lattice_sizes of {-1,
      0, 1} which represents unimodality constraints per dimension. 1 indicates
      that function first decreases then increases, -1 indicates that function
      first increases then decreases, 0 indicates no unimodality constraints.
    edgeworth_trusts: None or iterable of three-element tuples. First element is
      the index of the main (monotonic) feature. Second element is the index of
      the conditional feature. Third element is the direction of trust: 1 if
        higher values of the conditional feature should increase trust in the
        main feature and -1 otherwise.
    trapezoid_trusts: None or iterable of three-element tuples. First element is
      the index of the main (monotonic) feature. Second element is the index of
      the conditional feature. Third element is the direction of trust: 1 if
        higher values of the conditional feature should increase trust in the
        main feature and -1 otherwise.
    monotonic_dominances: None or iterable of two-element tuples. First element
      is the index of the dominant feature. Second element is the index of the
      weak feature.
    range_dominances: None or iterable of two-element tuples. First element is
      the index of the dominant feature. Second element is the index of the weak
      feature.
    joint_monotonicities: None or iterable of two-element tuples. Each tuple
      represents a pair of feature indices that require joint monotonicity.
    joint_unimodalities: None or tuple or iterable of tuples. Each tuple
      represents indices of single group of jointly unimodal features followed
      by 'valley' or 'peak'.
    num_iterations: number of iterations of Dykstra's algorithm.

  Returns:
    Projected weights tensor of same shape as `weights`.
  """
  # Trivial exits: nothing to iterate, or no active constraints.
  if num_iterations == 0:
    return weights
  # NOTE(review): edgeworth_trusts, trapezoid_trusts and monotonic_dominances
  # are not part of this early-exit test — presumably because they are only
  # meaningful together with per-dimension monotonicity; confirm with callers.
  if (utils.count_non_zeros(monotonicities, unimodalities) == 0 and
      not joint_monotonicities and not joint_unimodalities and
      not range_dominances):
    return weights

  units = weights.shape[1]
  # Normalize all constraint arguments to concrete (possibly empty)
  # containers so the loop body below can iterate them unconditionally.
  if monotonicities is None:
    monotonicities = [0] * len(lattice_sizes)
  if unimodalities is None:
    unimodalities = [0] * len(lattice_sizes)
  if edgeworth_trusts is None:
    edgeworth_trusts = []
  if trapezoid_trusts is None:
    trapezoid_trusts = []
  if monotonic_dominances is None:
    monotonic_dominances = []
  if range_dominances is None:
    range_dominances = []
  if joint_monotonicities is None:
    joint_monotonicities = []
  if joint_unimodalities is None:
    joint_unimodalities = []
  # Multi-unit lattices are handled by treating the units axis as one extra,
  # unconstrained lattice dimension so all projections broadcast across units.
  if units > 1:
    lattice_sizes = lattice_sizes + [int(units)]
    monotonicities = monotonicities + [0]
    unimodalities = unimodalities + [0]

  # Reshape flat weights into the full lattice grid for per-dimension slicing.
  weights = tf.reshape(weights, lattice_sizes)

  def body(iteration, weights, last_change):
    """Body of the tf.while_loop for Dykstra's projection algorithm.

    This implements Dykstra's projection algorithm and requires rolling back
    the last projection change.

    Args:
      iteration: Iteration counter tensor.
      weights: Tensor with projected weights at each iteration.
      last_change: Dict that stores the last change in the weights after
        projecting onto the each subset of constraints.

    Returns:
      The tuple (iteration, weights, last_change) at the end of each iteration.
    """
    # Shallow copy: tf.while_loop requires the input structure to stay
    # unmodified; we rebind keys on the copy instead.
    last_change = copy.copy(last_change)
    for dim in range(len(lattice_sizes)):
      if monotonicities[dim] == 0 and unimodalities[dim] == 0:
        continue

      for constraint_group in [0, 1]:
        # Iterate over 2 sets of constraints per dimension: even and odd.
        # Odd set exists only when there are more than 2 lattice vertices.
        if constraint_group + 1 >= lattice_sizes[dim]:
          continue

        # Rolling back last projection into current set as required by Dykstra's
        # algorithm.
        rolled_back_weights = weights - last_change[("MONOTONICITY", dim,
                                                     constraint_group)]
        weights = _project_partial_monotonicity(rolled_back_weights,
                                                lattice_sizes, monotonicities,
                                                unimodalities, dim,
                                                constraint_group)
        last_change[("MONOTONICITY", dim,
                     constraint_group)] = weights - rolled_back_weights

    for constraint in edgeworth_trusts:
      main_dim, cond_dim, _ = constraint
      # 4 constraint groups: parity of index along main and conditional dims.
      for constraint_group in [(0, 0), (0, 1), (1, 0), (1, 1)]:
        if (constraint_group[0] >= lattice_sizes[main_dim] - 1 or
            constraint_group[1] >= lattice_sizes[cond_dim] - 1):
          continue

        rolled_back_weights = (
            weights - last_change[("EDGEWORTH", constraint, constraint_group)])
        weights = _project_partial_edgeworth(rolled_back_weights, lattice_sizes,
                                             constraint, constraint_group)
        last_change[("EDGEWORTH", constraint,
                     constraint_group)] = weights - rolled_back_weights

    for constraint in trapezoid_trusts:
      _, cond_dim, _ = constraint
      for constraint_group in [0, 1]:
        if constraint_group >= lattice_sizes[cond_dim] - 1:
          continue

        rolled_back_weights = (
            weights - last_change[("TRAPEZOID", constraint, constraint_group)])
        weights = _project_partial_trapezoid(rolled_back_weights, lattice_sizes,
                                             constraint, constraint_group)
        last_change[("TRAPEZOID", constraint,
                     constraint_group)] = weights - rolled_back_weights

    for constraint in monotonic_dominances:
      dominant_dim, weak_dim = constraint
      # 8 constraint groups; third parity selects the half of each unit cell.
      for constraint_group in itertools.product([0, 1], [0, 1], [0, 1]):
        if (constraint_group[0] >= lattice_sizes[dominant_dim] - 1 or
            constraint_group[1] >= lattice_sizes[weak_dim] - 1):
          continue

        rolled_back_weights = weights - last_change[
            ("MONOTONIC_DOMINANCE", constraint, constraint_group)]
        weights = _project_partial_monotonic_dominance(rolled_back_weights,
                                                       lattice_sizes,
                                                       constraint,
                                                       constraint_group)
        last_change[("MONOTONIC_DOMINANCE", constraint,
                     constraint_group)] = weights - rolled_back_weights

    for constraint in range_dominances:
      dominant_dim, weak_dim = constraint
      # One constraint group per pair of vertex indices along the two dims.
      dom_dim_idx = range(lattice_sizes[dominant_dim])
      weak_dim_idx = range(lattice_sizes[weak_dim])
      for constraint_group in itertools.product(dom_dim_idx, weak_dim_idx):
        rolled_back_weights = weights - last_change[
            ("RANGE_DOMINANCE", constraint, constraint_group)]
        weights = _project_partial_range_dominance(rolled_back_weights,
                                                   lattice_sizes, constraint,
                                                   constraint_group)
        last_change[("RANGE_DOMINANCE", constraint,
                     constraint_group)] = weights - rolled_back_weights

    for constraint in joint_monotonicities:
      dim1, dim2 = constraint
      for constraint_group in itertools.product([0, 1], [0, 1], [0, 1]):
        if (constraint_group[0] >= lattice_sizes[dim1] - 1 or
            constraint_group[1] >= lattice_sizes[dim2] - 1):
          continue

        rolled_back_weights = weights - last_change[
            ("JOINT_MONOTONICITY", constraint, constraint_group)]
        weights = _project_partial_joint_monotonicity(rolled_back_weights,
                                                      lattice_sizes, constraint,
                                                      constraint_group)
        last_change[("JOINT_MONOTONICITY", constraint,
                     constraint_group)] = weights - rolled_back_weights

    for constraint in joint_unimodalities:
      dimensions = tuple(constraint[0])
      lattice_ranges = [range(lattice_sizes[dim]) for dim in dimensions]
      for vertex in itertools.product(*lattice_ranges):
        for offsets in itertools.product([-1, 1], repeat=len(dimensions)):
          # For this projection constraint group is represented by pair: vertex,
          # offsets.
          projection_key = ("JOINT_UNIMODALITY", dimensions, vertex, offsets)
          if projection_key in last_change:
            rolled_back_weights = weights - last_change[projection_key]
          else:
            rolled_back_weights = weights
          projected_weights = _project_partial_joint_unimodality(
              weights=rolled_back_weights,
              lattice_sizes=lattice_sizes,
              joint_unimodalities=constraint,
              vertex=vertex,
              offsets=offsets)
          # None means this (vertex, offsets) pair has no applicable
          # constraint, so no change is recorded and no key is created.
          if projected_weights is not None:
            weights = projected_weights
            last_change[projection_key] = weights - rolled_back_weights
    return iteration + 1, weights, last_change

  def cond(iteration, weights, last_change):
    # Loop state must be accepted even though only the counter is tested.
    del weights, last_change
    return tf.less(iteration, num_iterations)

  # Run the body of the loop once to find required last_change keys. The set of
  # keys in the input and output of the body of tf.while_loop must be the same.
  # The resulting ops are discarded and will not be part of the TF graph.
  zeros = tf.zeros(shape=lattice_sizes, dtype=weights.dtype)
  last_change = collections.defaultdict(lambda: zeros)
  (_, _, last_change) = body(0, weights, last_change)

  # Apply Dykstra's algorithm with tf.while_loop.
  iteration = tf.constant(0)
  # Re-initialize every discovered key to zero change before the real loop.
  last_change = {k: zeros for k in last_change}
  (_, weights, _) = tf.while_loop(cond, body, (iteration, weights, last_change))
  # Restore the original flat `(prod(lattice_sizes), units)` layout.
  return tf.reshape(weights, shape=[-1, units])