in tensorflow_lattice/python/lattice_lib.py [0:0]
def torsion_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):
"""Returns Torsion regularization loss for `Lattice` layer.
Lattice torsion regularizer penalizes how much the lattice function twists
from side-to-side (see
[publication](http://jmlr.org/papers/v17/15-243.html)).
Consider a 3 x 2 lattice with weights `w`:
```
w[3]-----w[4]-----w[5]
| | |
| | |
w[0]-----w[1]-----w[2]
```
In this case, the torsion regularizer is defined as:
```
l1 * (|w[4] + w[0] - w[3] - w[1]| + |w[5] + w[1] - w[4] - w[2]|) +
l2 * ((w[4] + w[0] - w[3] - w[1])^2 + (w[5] + w[1] - w[4] - w[2])^2)
```
Arguments:
weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.
lattice_sizes: List or tuple of integers which represents lattice sizes.
l1: l1 regularization amount. Either single float or list or tuple of floats
to specify different regularization amount per dimension.
l2: l2 regularization amount. Either single float or list or tuple of floats
to specify different regularization amount per dimension. The amount for
the interaction term between i and j is the corresponding product of each
per feature amount.
  Returns:
    Torsion regularization loss.
  """
  rank = len(lattice_sizes)
  if rank == 1 or (not l1 and not l2):
    return 0.0

  # If regularization amount is given as single float assume same amount for
  # every dimension.
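  # Take the square root so that the per-pair product l1[i] * l1[j] used below
  # recovers the original scalar amount for each pair of dimensions.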
  if l1 and not isinstance(l1, (list, tuple)):
    l1 = [math.sqrt(l1)] * rank
  if l2 and not isinstance(l2, (list, tuple)):
    l2 = [math.sqrt(l2)] * rank
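
  # For multi-unit lattices, treat the units axis as one extra lattice
  # dimension with zero regularization amount so that weights of different
  # units never interact in the torsion terms.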
  if weights.shape[1] > 1:
    lattice_sizes = lattice_sizes + [int(weights.shape[1])]
    rank += 1
    if l1:
      l1 = l1 + [0.0]
    if l2:
      l2 = l2 + [0.0]

  weights = tf.reshape(weights, shape=lattice_sizes)
  result = tf.constant(0.0, shape=[], dtype=weights.dtype)
  for i in range(rank - 1):
    for j in range(i + 1, rank):
      if ((not l1 or not l1[i] or not l1[j]) and
          (not l2 or not l2[i] or not l2[j])):
        continue
      if j == 1:
        planes = weights
      else:
        # Transpose so dimensions i and j become first in order to simplify
        # indexing and be able to merge all other dimensions into 1 for better
        # TPU performance.
        permut = [p for p in range(rank)]
        permut[0], permut[i] = permut[i], permut[0]
        permut[1], permut[j] = permut[j], permut[1]
        planes = tf.transpose(weights, perm=permut)
      planes = tf.reshape(
          planes, shape=[lattice_sizes[i], lattice_sizes[j], -1])
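      # The four corners of every 2x2 cell in the (i, j) plane; their mixed
      # second difference a00 + a11 - a01 - a10 is the per-cell torsion.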
      a00 = planes[0:-1, 0:-1]
      a01 = planes[0:-1, 1:]
      a10 = planes[1:, 0:-1]
      a11 = planes[1:, 1:]
      torsion = a00 + a11 - a01 - a10
      if l1:
        result += tf.reduce_sum(tf.abs(torsion)) * l1[i] * l1[j]
      if l2:
        result += tf.reduce_sum(tf.square(torsion)) * l2[i] * l2[j]
  return result
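

# --- Illustration (not part of the library) ---------------------------------
# A minimal sketch checking the docstring's 3 x 2 example against the
# implementation above. It assumes `tf` and `math` are imported at module
# level (as lattice_lib.py does) and TF 2.x eager execution. Note that
# `tf.reshape` is row-major, so the horizontal size-3 axis of the docstring
# diagram is the last, fastest-varying dimension, i.e. lattice_sizes = [2, 3].
def _check_torsion_docstring_example():
  # Weights w[0]..w[5] laid out as in the docstring diagram, single unit.
  w = [0.0, 1.0, 4.0, 0.0, 2.0, 3.0]
  weights = tf.constant([[v] for v in w])  # shape (6, 1)
  loss = torsion_regularizer(weights, lattice_sizes=[2, 3], l1=1.0)
  # Expected: |w4 + w0 - w3 - w1| + |w5 + w1 - w4 - w2| = |1| + |-2| = 3.
  expected = abs(w[4] + w[0] - w[3] - w[1]) + abs(w[5] + w[1] - w[4] - w[2])
  assert abs(float(loss) - expected) < 1e-6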