def torsion_regularizer()

in tensorflow_lattice/python/lattice_lib.py


import math

import tensorflow as tf


def torsion_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):
  """Returns Torsion regularization loss for `Lattice` layer.

  Lattice torsion regularizer penalizes how much the lattice function twists
  from side-to-side (see
  [publication](http://jmlr.org/papers/v17/15-243.html)).

  Consider a 3 x 2 lattice with weights `w`:

  ```
  w[3]-----w[4]-----w[5]
    |        |        |
    |        |        |
  w[0]-----w[1]-----w[2]
  ```

  In this case, the torsion regularizer is defined as:

  ```
  l1 * (|w[4] + w[0] - w[3] - w[1]| + |w[5] + w[1] - w[4] - w[2]|) +
  l2 * ((w[4] + w[0] - w[3] - w[1])^2 + (w[5] + w[1] - w[4] - w[2])^2)
  ```

  Arguments:
    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.
    lattice_sizes: List or tuple of integers which represents lattice sizes.
    l1: l1 regularization amount. Either a single float or a list or tuple of
      floats specifying a different regularization amount per dimension. The
      amount for the interaction term between dimensions i and j is the
      product of the corresponding per-dimension amounts.
    l2: l2 regularization amount. Either a single float or a list or tuple of
      floats specifying a different regularization amount per dimension. The
      amount for the interaction term between dimensions i and j is the
      product of the corresponding per-dimension amounts.

  Returns:
    Torsion regularization loss.
  """
  rank = len(lattice_sizes)
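  # A 1-d lattice has no torsion, and zero amounts mean no regularization.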
  if rank == 1 or (not l1 and not l2):
    return 0.0

  # If the regularization amount is given as a single float, assume the same
  # amount for every dimension.
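  # The square root makes the per-pair product l[i] * l[j] equal the original
  # scalar amount.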
  if l1 and not isinstance(l1, (list, tuple)):
    l1 = [math.sqrt(l1)] * rank
  if l2 and not isinstance(l2, (list, tuple)):
    l2 = [math.sqrt(l2)] * rank

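  # For multi-unit lattices, treat the units axis as an extra dimension with
  # zero regularization amount so that no torsion is applied across units.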
  if weights.shape[1] > 1:
    lattice_sizes = list(lattice_sizes) + [int(weights.shape[1])]
    rank += 1
    if l1:
      l1 = list(l1) + [0.0]
    if l2:
      l2 = list(l2) + [0.0]
  weights = tf.reshape(weights, shape=lattice_sizes)

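  # Accumulate the loss over every pair of lattice dimensions (i, j).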
  result = tf.constant(0.0, shape=[], dtype=weights.dtype)
  for i in range(rank - 1):
    for j in range(i + 1, rank):
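      # Skip the pair unless at least one of l1/l2 is nonzero for both
      # dimensions i and j.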
      if ((not l1 or not l1[i] or not l1[j]) and
          (not l2 or not l2[i] or not l2[j])):
        continue
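      # For the first pair (i == 0, j == 1) the two dimensions already lead.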
      if j == 1:
        planes = weights
      else:
        # Transpose so that dimensions i and j come first. This simplifies
        # indexing and lets all remaining dimensions be merged into one for
        # better TPU performance.
        permut = list(range(rank))
        permut[0], permut[i] = permut[i], permut[0]
        permut[1], permut[j] = permut[j], permut[1]
        planes = tf.transpose(weights, perm=permut)
      planes = tf.reshape(
          planes, shape=[lattice_sizes[i], lattice_sizes[j], -1])

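      # Corners of every 2x2 cell in the (i, j) plane; torsion measures how
      # far each cell is from being planar.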
      a00 = planes[0:-1, 0:-1]
      a01 = planes[0:-1, 1:]
      a10 = planes[1:, 0:-1]
      a11 = planes[1:, 1:]
      torsion = a00 + a11 - a01 - a10

      if l1:
        result += tf.reduce_sum(tf.abs(torsion)) * l1[i] * l1[j]
      if l2:
        result += tf.reduce_sum(tf.square(torsion)) * l2[i] * l2[j]
  return result
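
A minimal usage sketch (not part of the library source); it assumes eager
execution and imports the module from the path shown above. The lattice sizes
and regularization amounts are arbitrary illustration values:

```
import tensorflow as tf

from tensorflow_lattice.python import lattice_lib

# A 2 x 3 x 2 lattice with a single output unit: the weights tensor has
# shape (prod(lattice_sizes), units) = (12, 1).
lattice_sizes = [2, 3, 2]
weights = tf.random.uniform((12, 1))

# A scalar l1 applies the same amount to every dimension; the l2 list gives
# per-dimension amounts, so the pair (i, j) is weighted by l2[i] * l2[j].
loss = lattice_lib.torsion_regularizer(
    weights, lattice_sizes, l1=0.01, l2=[0.1, 0.2, 0.0])
print(float(loss))
```

In a full model this loss is usually attached through the `Lattice` layer's
regularizer configuration rather than called by hand.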