in tensorflow_lattice/python/lattice_lib.py [0:0]
def project_by_dykstra(weights,
lattice_sizes,
monotonicities=None,
unimodalities=None,
edgeworth_trusts=None,
trapezoid_trusts=None,
monotonic_dominances=None,
range_dominances=None,
joint_monotonicities=None,
joint_unimodalities=None,
num_iterations=1):
"""Applies dykstra's projection algorithm for monotonicity/trust constraints.
- Returns honest projection with respect to L2 norm if num_iterations is inf.
- Monotonicity will be violated by some small eps(num_iterations).
- Complexity: O(num_iterations * (num_monotonic_dims + num_trust_constraints)
* num_lattice_weights)
Dykstra's alternating projections algorithm projects into intersection of
several convex sets. For algorithm description itself use Google or Wiki:
https://en.wikipedia.org/wiki/Dykstra%27s_projection_algorithm
Here, each monotonicity constraint is split up into 2 independent convex sets
each trust constraint is split up into 4 independent convex sets. These sets
are then projected onto exactly (in L2 space). For more details, see the
_project_partial_* functions.
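
  For example, to make a single-unit 2 x 3 lattice monotonic in both dimensions
  (the weight values below are purely illustrative):

    lattice_sizes = [2, 3]
    # Shape: (prod(lattice_sizes), units) = (6, 1).
    weights = tf.constant([[0.0], [2.0], [1.0], [3.0], [0.5], [4.0]])
    projected = project_by_dykstra(
        weights, lattice_sizes, monotonicities=[1, 1], num_iterations=10)
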
Args:
weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.
    lattice_sizes: list or tuple of integers which represents lattice sizes
      which correspond to weights.
monotonicities: None or list or tuple of same length as lattice_sizes of {0,
1} which represents monotonicity constraints per dimension. 1 stands for
increasing (non-decreasing in fact), 0 for no monotonicity constraints.
unimodalities: None or list or tuple of same length as lattice_sizes of {-1,
0, 1} which represents unimodality constraints per dimension. 1 indicates
that function first decreases then increases, -1 indicates that function
first increases then decreases, 0 indicates no unimodality constraints.
edgeworth_trusts: None or iterable of three-element tuples. First element is
the index of the main (monotonic) feature. Second element is the index of
the conditional feature. Third element is the direction of trust: 1 if
higher values of the conditional feature should increase trust in the
main feature and -1 otherwise.
trapezoid_trusts: None or iterable of three-element tuples. First element is
the index of the main (monotonic) feature. Second element is the index of
the conditional feature. Third element is the direction of trust: 1 if
higher values of the conditional feature should increase trust in the
main feature and -1 otherwise.
monotonic_dominances: None or iterable of two-element tuples. First element
is the index of the dominant feature. Second element is the index of the
weak feature.
range_dominances: None or iterable of two-element tuples. First element is
the index of the dominant feature. Second element is the index of the weak
feature.
    joint_monotonicities: None or iterable of two-element tuples. Each tuple
      represents a pair of feature indices that require joint monotonicity.
joint_unimodalities: None or tuple or iterable of tuples. Each tuple
represents indices of single group of jointly unimodal features followed
by 'valley' or 'peak'.
num_iterations: number of iterations of Dykstra's algorithm.
Returns:
Projected weights tensor of same shape as `weights`.
"""
if num_iterations == 0:
return weights
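  # Nothing to project if there are no monotonicity/unimodality constraints and
  # no joint or range dominance constraints.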
if (utils.count_non_zeros(monotonicities, unimodalities) == 0 and
not joint_monotonicities and not joint_unimodalities and
not range_dominances):
return weights
units = weights.shape[1]
if monotonicities is None:
monotonicities = [0] * len(lattice_sizes)
if unimodalities is None:
unimodalities = [0] * len(lattice_sizes)
if edgeworth_trusts is None:
edgeworth_trusts = []
if trapezoid_trusts is None:
trapezoid_trusts = []
if monotonic_dominances is None:
monotonic_dominances = []
if range_dominances is None:
range_dominances = []
if joint_monotonicities is None:
joint_monotonicities = []
if joint_unimodalities is None:
joint_unimodalities = []
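  # Treat the units dimension as one more lattice dimension with no
  # constraints, so multi-unit weights can be projected with the same
  # per-dimension logic.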
if units > 1:
lattice_sizes = lattice_sizes + [int(units)]
monotonicities = monotonicities + [0]
unimodalities = unimodalities + [0]
weights = tf.reshape(weights, lattice_sizes)
def body(iteration, weights, last_change):
"""Body of the tf.while_loop for Dykstra's projection algorithm.
This implements Dykstra's projection algorithm and requires rolling back
the last projection change.
Args:
iteration: Iteration counter tensor.
      weights: Tensor with projected weights at each iteration.
      last_change: Dict that stores the last change in the weights after
        projecting onto each subset of constraints.
Returns:
The tuple (iteration, weights, last_change) at the end of each iteration.
"""
last_change = copy.copy(last_change)
for dim in range(len(lattice_sizes)):
if monotonicities[dim] == 0 and unimodalities[dim] == 0:
continue
for constraint_group in [0, 1]:
# Iterate over 2 sets of constraints per dimension: even and odd.
# Odd set exists only when there are more than 2 lattice vertices.
if constraint_group + 1 >= lattice_sizes[dim]:
continue
# Rolling back last projection into current set as required by Dykstra's
# algorithm.
rolled_back_weights = weights - last_change[("MONOTONICITY", dim,
constraint_group)]
weights = _project_partial_monotonicity(rolled_back_weights,
lattice_sizes, monotonicities,
unimodalities, dim,
constraint_group)
last_change[("MONOTONICITY", dim,
constraint_group)] = weights - rolled_back_weights
for constraint in edgeworth_trusts:
main_dim, cond_dim, _ = constraint
for constraint_group in [(0, 0), (0, 1), (1, 0), (1, 1)]:
if (constraint_group[0] >= lattice_sizes[main_dim] - 1 or
constraint_group[1] >= lattice_sizes[cond_dim] - 1):
continue
rolled_back_weights = (
weights - last_change[("EDGEWORTH", constraint, constraint_group)])
weights = _project_partial_edgeworth(rolled_back_weights, lattice_sizes,
constraint, constraint_group)
last_change[("EDGEWORTH", constraint,
constraint_group)] = weights - rolled_back_weights
for constraint in trapezoid_trusts:
_, cond_dim, _ = constraint
for constraint_group in [0, 1]:
if constraint_group >= lattice_sizes[cond_dim] - 1:
continue
rolled_back_weights = (
weights - last_change[("TRAPEZOID", constraint, constraint_group)])
weights = _project_partial_trapezoid(rolled_back_weights, lattice_sizes,
constraint, constraint_group)
last_change[("TRAPEZOID", constraint,
constraint_group)] = weights - rolled_back_weights
for constraint in monotonic_dominances:
dominant_dim, weak_dim = constraint
for constraint_group in itertools.product([0, 1], [0, 1], [0, 1]):
if (constraint_group[0] >= lattice_sizes[dominant_dim] - 1 or
constraint_group[1] >= lattice_sizes[weak_dim] - 1):
continue
rolled_back_weights = weights - last_change[
("MONOTONIC_DOMINANCE", constraint, constraint_group)]
weights = _project_partial_monotonic_dominance(rolled_back_weights,
lattice_sizes,
constraint,
constraint_group)
last_change[("MONOTONIC_DOMINANCE", constraint,
constraint_group)] = weights - rolled_back_weights
for constraint in range_dominances:
dominant_dim, weak_dim = constraint
dom_dim_idx = range(lattice_sizes[dominant_dim])
weak_dim_idx = range(lattice_sizes[weak_dim])
for constraint_group in itertools.product(dom_dim_idx, weak_dim_idx):
rolled_back_weights = weights - last_change[
("RANGE_DOMINANCE", constraint, constraint_group)]
weights = _project_partial_range_dominance(rolled_back_weights,
lattice_sizes, constraint,
constraint_group)
last_change[("RANGE_DOMINANCE", constraint,
constraint_group)] = weights - rolled_back_weights
for constraint in joint_monotonicities:
dim1, dim2 = constraint
for constraint_group in itertools.product([0, 1], [0, 1], [0, 1]):
if (constraint_group[0] >= lattice_sizes[dim1] - 1 or
constraint_group[1] >= lattice_sizes[dim2] - 1):
continue
rolled_back_weights = weights - last_change[
("JOINT_MONOTONICITY", constraint, constraint_group)]
weights = _project_partial_joint_monotonicity(rolled_back_weights,
lattice_sizes, constraint,
constraint_group)
last_change[("JOINT_MONOTONICITY", constraint,
constraint_group)] = weights - rolled_back_weights
for constraint in joint_unimodalities:
dimensions = tuple(constraint[0])
lattice_ranges = [range(lattice_sizes[dim]) for dim in dimensions]
for vertex in itertools.product(*lattice_ranges):
for offsets in itertools.product([-1, 1], repeat=len(dimensions)):
          # For this projection the constraint group is represented by the pair
          # (vertex, offsets).
projection_key = ("JOINT_UNIMODALITY", dimensions, vertex, offsets)
if projection_key in last_change:
rolled_back_weights = weights - last_change[projection_key]
else:
rolled_back_weights = weights
projected_weights = _project_partial_joint_unimodality(
weights=rolled_back_weights,
lattice_sizes=lattice_sizes,
joint_unimodalities=constraint,
vertex=vertex,
offsets=offsets)
if projected_weights is not None:
weights = projected_weights
last_change[projection_key] = weights - rolled_back_weights
return iteration + 1, weights, last_change
def cond(iteration, weights, last_change):
del weights, last_change
return tf.less(iteration, num_iterations)
# Run the body of the loop once to find required last_change keys. The set of
# keys in the input and output of the body of tf.while_loop must be the same.
# The resulting ops are discarded and will not be part of the TF graph.
zeros = tf.zeros(shape=lattice_sizes, dtype=weights.dtype)
last_change = collections.defaultdict(lambda: zeros)
(_, _, last_change) = body(0, weights, last_change)
# Apply Dykstra's algorithm with tf.while_loop.
iteration = tf.constant(0)
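  # Reset every discovered entry of last_change to zeros before running the
  # actual projection loop.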
last_change = {k: zeros for k in last_change}
(_, weights, _) = tf.while_loop(cond, body, (iteration, weights, last_change))
return tf.reshape(weights, shape=[-1, units])