in tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py [0:0]
def fast_walsh_hadamard_transform(x):
  """Applies the fast Walsh-Hadamard transform to a set of vectors.

  This method uses a composition of existing TensorFlow operations to implement
  the transform. The result is normalized by dividing by `sqrt(b)`, where `b`
  is the size of the second dimension of `x`.

  Args:
    x: A `Tensor`. Must be of shape `[a, b]`, where `a` can be anything (not
      necessarily known), and `b` must be a power of two, not required to be
      statically known.

  Returns:
    A `Tensor` of shape `[a, b]`, where `[i, :]` is the product `x[i, :]*H`,
    where `H` is the Hadamard matrix.

  Raises:
    ValueError: If the input is not a rank 2 `Tensor`, or if the second
      dimension is statically known and is not a power of two.
    OpError: If the second dimension is not statically known and is not a power
      of two. Note that in graph execution, this error is not raised during the
      execution of the Python function, but during execution of the resulting
      computation.
  """
  with tf.compat.v1.name_scope(None, 'fast_walsh_hadamard_transform'):
    # Validate input.
    x = tf.convert_to_tensor(x)
    if x.shape.ndims != 2:
      raise ValueError(
          'Number of dimensions of x must be 2. Shape of x: %s' % x.shape)

    original_x_shape = x.shape.as_list()
    dim = original_x_shape[-1]

    if dim is None:  # dim is not statically known.
      dim = tf.shape(x)[-1]
      # Round log2(dim) to the nearest integer; dim is a power of two exactly
      # when 2**log2 reproduces dim, which the assertion below checks.
      log2 = tf.cast(
          tf.math.round(
              tf.math.log(tf.cast(dim, tf.float32)) / tf.math.log(2.)),
          tf.int32)
      with tf.control_dependencies([
          tf.compat.v1.assert_equal(
              dim,
              tf.math.pow(2, log2),
              message=('The dimension of x must be a power of two. '
                       'Provided dimension is: %s' % dim))
      ]):
        # The identity op ties the assertion to every downstream use of x.
        x = tf.identity(x)
    else:  # dim is statically known.
      # A positive integer is a power of two iff it has a single bit set.
      if not (dim and ((dim & (dim - 1)) == 0)):
        raise ValueError('The dimension of x must be a power of two. '
                         'Provided dimension is: %s' % dim)
      log2 = int(np.ceil(np.log2(dim)))
      if dim == 1:  # Equivalent to identity.
        return tf.identity(x)

    h_core = tf.constant([[1., 1.], [1., -1.]],
                         dtype=x.dtype,
                         name='hadamard_weights_2x2')
    permutation = tf.constant([0, 2, 1], name='hadamard_permutation')

    # A step of the fast Walsh-Hadamard algorithm.
    def _hadamard_step(x, dim):
      """A single step in the fast Walsh-Hadamard transform."""
      x_shape = x.shape.as_list()
      x = tf.reshape(x, [-1, 2])  # Reshape so that we have a matrix.
      x = tf.matmul(x, h_core)  # Multiply.
      x = tf.reshape(x, [-1, dim // 2, 2])  # Reshape to rank-3.
      x = tf.transpose(x, perm=permutation)  # Swap last two dimensions.
      x.set_shape(x_shape)  # Failed shape inference in tf.while_loop.
      return x

    def _fwht(x, dim, log2):
      """Runs `log2` Hadamard steps on `x` viewed as `[-1, 2, dim // 2]`."""
      x = tf.reshape(x, [-1, 2, dim // 2])
      # The fast Walsh-Hadamard transform.
      i = tf.constant(0)
      c = lambda i, x: tf.less(i, log2)
      b = lambda i, x: [i + 1, _hadamard_step(x, dim)]
      i, x = tf.while_loop(c, b, [i, x])
      return x

    # A statically known dim == 1 has already returned above; this guard
    # handles a dynamic dim == 1, for which _fwht would reshape to a
    # zero-sized last dimension (dim // 2 == 0).
    x = tf.cond(
        tf.equal(dim, 1), lambda: tf.identity(x), lambda: _fwht(x, dim, log2))

    x = tf.reshape(x, [-1, dim])
    x /= tf.sqrt(tf.cast(dim, x.dtype))  # Normalize.
    x.set_shape(original_x_shape)  # Failed shape inference after tf.while_loop.
    return x