in tensorflow_lattice/python/lattice_lib.py [0:0]
def laplacian_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):
  """Returns Laplacian regularization loss for `Lattice` layer.

  Laplacian regularizer penalizes the difference between adjacent vertices in
  multi-cell lattice (see
  [publication](http://jmlr.org/papers/v17/15-243.html)).

  Consider a 3 x 2 lattice with weights `w`:

  ```
  w[3]-----w[4]-----w[5]
    |        |        |
    |        |        |
  w[0]-----w[1]-----w[2]
  ```

  where the number at each node represents the weight index.
  In this case, the laplacian regularizer is defined as:

  ```
  l1[0] * (|w[1] - w[0]| + |w[2] - w[1]| +
           |w[4] - w[3]| + |w[5] - w[4]|) +
  l1[1] * (|w[3] - w[0]| + |w[4] - w[1]| + |w[5] - w[2]|) +

  l2[0] * ((w[1] - w[0])^2 + (w[2] - w[1])^2 +
           (w[4] - w[3])^2 + (w[5] - w[4])^2) +
  l2[1] * ((w[3] - w[0])^2 + (w[4] - w[1])^2 + (w[5] - w[2])^2)
  ```

  If `weights` has more than one unit (second dimension > 1), every unit is
  regularized independently: no penalty is applied across units.

  Arguments:
    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.
    lattice_sizes: List or tuple of integers which represents lattice sizes.
    l1: l1 regularization amount. Either single float or list or tuple of floats
      to specify different regularization amount per dimension.
    l2: l2 regularization amount. Either single float or list or tuple of floats
      to specify different regularization amount per dimension.

  Returns:
    Laplacian regularization loss (scalar tensor with dtype of `weights`), or
    the Python float `0.0` if both `l1` and `l2` are falsy.
  """
  if not l1 and not l2:
    return 0.0

  rank = len(lattice_sizes)
  # Work on list copies: the docstring allows tuples, and tuple + list (used
  # when appending the units dimension below) raises TypeError. Copying also
  # guarantees the caller's sequences are never mutated.
  lattice_sizes = list(lattice_sizes)
  # If regularization amount is given as single float assume same amount for
  # every dimension.
  if l1:
    l1 = list(l1) if isinstance(l1, (list, tuple)) else [l1] * rank
  if l2:
    l2 = list(l2) if isinstance(l2, (list, tuple)) else [l2] * rank

  if weights.shape[1] > 1:
    # Treat the units axis as one more lattice dimension with zero
    # regularization, so differences are only taken within each unit.
    lattice_sizes.append(int(weights.shape[1]))
    rank += 1
    if l1:
      l1.append(0.0)
    if l2:
      l2.append(0.0)

  weights = tf.reshape(weights, shape=lattice_sizes)
  result = tf.constant(0.0, shape=[], dtype=weights.dtype)
  for dim in range(rank):
    l1_amount = l1[dim] if l1 else 0.0
    l2_amount = l2[dim] if l2 else 0.0
    if not l1_amount and not l2_amount:
      continue
    if dim > 0:
      # Transpose so current dimension becomes first one in order to simplify
      # indexing and be able to merge all other dimensions into 1 for better TPU
      # performance.
      perm = list(range(rank))
      perm[0], perm[dim] = perm[dim], perm[0]
      slices = tf.transpose(weights, perm=perm)
    else:
      slices = weights
    slices = tf.reshape(slices, shape=[lattice_sizes[dim], -1])
    # Differences between every pair of vertices adjacent along `dim`.
    diff = slices[1:] - slices[:-1]
    if l1_amount:
      result += tf.reduce_sum(tf.abs(diff)) * l1_amount
    if l2_amount:
      result += tf.reduce_sum(tf.square(diff)) * l2_amount
  return result