def fast_walsh_hadamard_transform(x)

in tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py


import numpy as np
import tensorflow as tf


def fast_walsh_hadamard_transform(x):
  """Applies the fast Walsh-Hadamard transform to a set of vectors.

  This method uses a composition of existing TensorFlow operations to implement
  the transform.

  Args:
    x: A `Tensor`. Must be of shape `[a, b]`, where `a` can be anything (not
      necessarily known), and `b` must be a power of two, not required to be
      statically known.

  Returns:
    A `Tensor` of shape `[a, b]`, where row `i` is the product
      `x[i, :] * H / sqrt(b)` and `H` is the Hadamard matrix of order `b`.

  Raises:
    ValueError: If the input is not a rank 2 `Tensor`, or if the second
      dimension is statically known and is not a power of two.
    OpError: If the second dimension is not statically known and is not a power
      of two. Note that in graph execution, this error is not raised during the
      execution of the Python function, but during execution of the resulting
      computation.
  """
  with tf.compat.v1.name_scope(None, 'fast_walsh_hadamard_transform'):
    # Validate input.
    x = tf.convert_to_tensor(x)
    if x.shape.ndims != 2:
      raise ValueError(
          'Number of dimensions of x must be 2. Shape of x: %s' % x.shape)

    original_x_shape = x.shape.as_list()
    dim = x.shape.as_list()[-1]

    if dim is None:  # dim is not statically known.
      dim = tf.shape(x)[-1]
      log2 = tf.cast(
          tf.math.round(
              tf.math.log(tf.cast(dim, tf.float32)) / tf.math.log(2.)),
          tf.int32)
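      # Rounding guards against floating-point error in the in-graph log2;
      # the assertion below then rejects any `dim` that is not a power of two.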
      with tf.control_dependencies([
          tf.compat.v1.assert_equal(
              dim,
              tf.math.pow(2, log2),
              message='The dimension of x must be a power of two. '
              'Provided dimension is: %s' % dim)
      ]):
        x = tf.identity(x)
    else:  # dim is statically known.
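      # A positive integer is a power of two iff it has a single set bit,
      # i.e. `dim & (dim - 1) == 0`.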
      if not (dim and ((dim & (dim - 1)) == 0)):
        raise ValueError('The dimension of x must be a power of two. '
                         'Provided dimension is: %s' % dim)
      log2 = int(np.ceil(np.log2(dim)))
      if dim == 1:  # Equivalent to identity.
        return tf.identity(x)

    h_core = tf.constant([[1., 1.], [1., -1.]],
                         dtype=x.dtype,
                         name='hadamard_weights_2x2')
    permutation = tf.constant([0, 2, 1], name='hadamard_permutation')
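    # `h_core` is the order-2 Hadamard matrix H_2. Since the order-2^k matrix
    # is the k-fold Kronecker power of H_2, the transform factors into
    # log2(dim) rounds of batched 2x2 matmuls interleaved with transpositions.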

    def _hadamard_step(x, dim):
      """A single step in the fast Walsh-Hadamard transform."""
      x_shape = x.shape.as_list()
      x = tf.reshape(x, [-1, 2])  # Pair adjacent elements for the butterfly.
      x = tf.matmul(x, h_core)  # Multiply.
      x = tf.reshape(x, [-1, dim // 2, 2])  # Reshape to rank-3.
      x = tf.transpose(x, perm=permutation)  # Swap last two dimensions.
      x.set_shape(x_shape)  # Restore shape info lost in tf.while_loop.
      return x

    def _fwht(x, dim, log2):
      """The fast Walsh-Hadamard transform: `log2` Hadamard steps over `x`."""
      x = tf.reshape(x, [-1, 2, dim // 2])
      i = tf.constant(0)
      cond = lambda i, x: tf.less(i, log2)
      body = lambda i, x: [i + 1, _hadamard_step(x, dim)]
      _, x = tf.while_loop(cond, body, [i, x])
      return x

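    # The static `dim == 1` case returned early above; for a statically
    # unknown `dim`, the degenerate case must be dispatched in-graph, since
    # `_fwht` would attempt an invalid reshape when `dim == 1`.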
    x = tf.cond(
        tf.equal(dim, 1), lambda: tf.identity(x), lambda: _fwht(x, dim, log2))

    x = tf.reshape(x, [-1, dim])
    x /= tf.sqrt(tf.cast(dim, x.dtype))  # Normalize.
    x.set_shape(original_x_shape)  # Restore shape info lost in tf.while_loop.
    return x
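
Example usage, as a minimal sketch (not part of the source file; it assumes
TensorFlow 2.x eager execution and the imports above):

x = tf.constant([[1., 0., 0., 0.],
                 [1., 1., 1., 1.]])
y = fast_walsh_hadamard_transform(x)
# y[0] is [0.5, 0.5, 0.5, 0.5]: the first row of H_4, scaled by 1/sqrt(4).
# y[1] is [2., 0., 0., 0.]: the all-ones vector concentrates into one entry.
# The normalized transform is involutory, so applying it twice recovers x
# (up to floating-point rounding).
z = fast_walsh_hadamard_transform(y)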