def laplacian_regularizer()

in tensorflow_lattice/python/lattice_lib.py


import tensorflow as tf


def laplacian_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):
  """Returns Laplacian regularization loss for `Lattice` layer.

  The Laplacian regularizer penalizes the differences between adjacent
  vertices in a multi-cell lattice (see the
  [publication](http://jmlr.org/papers/v17/15-243.html)).

  Consider a 3 x 2 lattice with weights `w`:

  ```
  w[3]-----w[4]-----w[5]
    |        |        |
    |        |        |
  w[0]-----w[1]-----w[2]
  ```

  where the number at each node represents the weight index.
  In this case, the Laplacian regularizer is defined as:

  ```
  l1[0] * (|w[1] - w[0]| + |w[2] - w[1]| +
           |w[4] - w[3]| + |w[5] - w[4]|) +
  l1[1] * (|w[3] - w[0]| + |w[4] - w[1]| + |w[5] - w[2]|) +

  l2[0] * ((w[1] - w[0])^2 + (w[2] - w[1])^2 +
           (w[4] - w[3])^2 + (w[5] - w[4])^2) +
  l2[1] * ((w[3] - w[0])^2 + (w[4] - w[1])^2 + (w[5] - w[2])^2)
  ```

  Arguments:
    weights: `Lattice` weights tensor of shape `(prod(lattice_sizes), units)`.
    lattice_sizes: List or tuple of integers representing the lattice sizes.
    l1: L1 regularization amount. Either a single float or a list or tuple of
      floats to specify a different regularization amount per dimension.
    l2: L2 regularization amount. Either a single float or a list or tuple of
      floats to specify a different regularization amount per dimension.

  Returns:
    Laplacian regularization loss.
  """
  if not l1 and not l2:
    return 0.0

  rank = len(lattice_sizes)
  # If the regularization amount is given as a single float, assume the same
  # amount for every dimension.
  if l1 and not isinstance(l1, (list, tuple)):
    l1 = [l1] * rank
  if l2 and not isinstance(l2, (list, tuple)):
    l2 = [l2] * rank

  if weights.shape[1] > 1:
    # Multi-unit layer: treat the units axis as an extra lattice dimension with
    # zero regularization amount so all units are handled in a single pass.
    # Convert to lists first so that tuple inputs (allowed per the docstring)
    # do not fail on concatenation.
    lattice_sizes = list(lattice_sizes) + [int(weights.shape[1])]
    rank += 1
    if l1:
      l1 = list(l1) + [0.0]
    if l2:
      l2 = list(l2) + [0.0]
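  # View the flat weights as a rank-`rank` tensor so that adjacent vertices
  # along each lattice dimension can be differenced independently.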
  weights = tf.reshape(weights, shape=lattice_sizes)

  result = tf.constant(0.0, shape=[], dtype=weights.dtype)
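  # Accumulate the penalty one dimension at a time; dimensions whose
  # regularization amounts are all zero are skipped.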
  for dim in range(rank):
    if (not l1 or not l1[dim]) and (not l2 or not l2[dim]):
      continue
    if dim > 0:
      # Transpose so that the current dimension becomes the first one. This
      # simplifies indexing and lets all other dimensions be merged into one
      # for better TPU performance.
      permut = list(range(rank))
      permut[0], permut[dim] = permut[dim], permut[0]
      slices = tf.transpose(weights, perm=permut)
    else:
      slices = weights
    slices = tf.reshape(slices, shape=[lattice_sizes[dim], -1])

    # First-order differences between adjacent vertices along this dimension.
    diff = slices[1:] - slices[0:-1]
    if l1:
      result += tf.reduce_sum(tf.abs(diff)) * l1[dim]
    if l2:
      result += tf.reduce_sum(tf.square(diff)) * l2[dim]
  return result
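
For reference, here is a minimal usage sketch (eager TensorFlow, with
illustrative weight values), assuming `laplacian_regularizer` is in scope,
e.g. imported from `tensorflow_lattice.python.lattice_lib`:

```
import tensorflow as tf

# Weights for the 3 x 2 lattice from the docstring, flattened to shape
# (prod(lattice_sizes), units) = (6, 1). Values are illustrative.
w = tf.constant([[0.0], [1.0], [3.0], [0.5], [1.5], [3.5]])

# Per-dimension L1 amounts and a single shared L2 amount.
loss = laplacian_regularizer(w, lattice_sizes=[3, 2], l1=[0.1, 0.2], l2=0.01)
print(float(loss))

# A constant lattice has no differences between adjacent vertices, so the
# penalty vanishes regardless of the regularization amounts.
flat = tf.ones(shape=(6, 1))
assert float(laplacian_regularizer(flat, [3, 2], l1=1.0, l2=1.0)) == 0.0

# With units > 1, the units axis is treated internally as an extra lattice
# dimension with zero regularization amount, so each unit is penalized
# independently.
w2 = tf.random.uniform(shape=(6, 2))
loss2 = laplacian_regularizer(w2, [3, 2], l1=0.1)
```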