def _cmwc_random_sequence()

in tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py [0:0]


def _cmwc_random_sequence(num_elements, seed):
  """Implements a version of the Complementary Multiply with Carry algorithm.

  http://en.wikipedia.org/wiki/Multiply-with-carry

  This implementation serves as a purely TensorFlow implementation of a fully
  deterministic source of pseudo-random number sequence. That is, given a
  `Tensor` `seed`, this method will output a `Tensor` with `num_elements`
  elements, that will produce the same sequence when evaluated (assuming the
  same value of the `Tensor` `seed`).

  This method is not particularly efficient, does not come with any guarantee of
  the period length, and should be replaced by appropriate alternative in
  TensorFlow 2.x. In a test in general colab runtime, it took ~0.5s to generate
  1 million values.

  Args:
    num_elements: A Python integer. The number of random values to be generated.
    seed: A scalar `Tensor` of type `tf.int64`.

  Returns:
    A `Tensor` of shape `(num_elements)` and dtype tf.float64, containing random
    values in the range `[0, 1)`.

  Raises:
    TypeError: If `num_elements` is not a Python integer or `seed` is not a
      `tf.int64` `Tensor`.
    ValueError: If `num_elements` is not positive.
  """
  if not isinstance(num_elements, int):
    raise TypeError('The num_elements argument must be a Python integer.')
  if num_elements <= 0:
    raise ValueError('The num_elements argument must be positive.')
  if not tf.is_tensor(seed) or seed.dtype != tf.int64:
    raise TypeError('The seed argument must be a tf.int64 Tensor.')

  # For better efficiency of tf.while_loop, we generate `parallelism` random
  # sequences in parallel. The specific constant (sqrt(num_elements) / 10) is
  # hand picked after simple benchmarking for large values of num_elements.
  parallelism = int(math.ceil(math.sqrt(num_elements) / 10))
  # The extra iteration guarantees at least num_elements values are produced
  # even when num_elements is not a multiple of parallelism; the surplus is
  # sliced off at the end.
  num_iters = num_elements // parallelism + 1

  # Create constants needed for the algorithm. The constants and notation
  # follows from the above reference.
  a = tf.tile(tf.constant([3636507990], tf.int64), [parallelism])
  b = tf.tile(tf.constant([2**32], tf.int64), [parallelism])
  logb_scalar = tf.constant(32, tf.int64)
  logb = tf.tile([logb_scalar], [parallelism])
  # `f` accumulates raw generated bits; `bits` counts how many valid bits are
  # currently held in `f`.
  f = tf.tile(tf.constant([0], dtype=tf.int64), [parallelism])
  bits = tf.constant(0, dtype=tf.int64, name='bits')

  # TensorArray used in tf.while_loop for efficiency.
  values = tf.TensorArray(
      dtype=tf.float64, size=num_iters, element_shape=[parallelism])
  # Iteration counter.
  num = tf.constant(0, dtype=tf.int32, name='num')
  # Each output value is assembled from 53 random bits — the significand
  # precision of tf.float64.
  val_53 = tf.constant(53, tf.int64, name='val_53')

  # Construct initial sequence of seeds.
  # From a single input seed, we construct multiple starting seeds for the
  # sequences to be computed in parallel.
  def next_seed_fn(i, val, q):
    val = val**7 + val**6 + 1  # PRBS7.
    q = q.write(i, val)
    return i + 1, val, q

  q = tf.TensorArray(dtype=tf.int64, size=parallelism, element_shape=())
  _, _, q = tf.while_loop(lambda i, _, __: i < parallelism,
                          next_seed_fn,
                          [tf.constant(0), seed, q])
  c = q = q.stack()

  # The random sequence generation code.
  def cmwc_step(f, bits, q, c, num, values):
    """A single step of the modified CMWC algorithm."""
    t = a * q + c
    c = b - 1 - tf.bitwise.right_shift(t, logb)
    x = q = tf.bitwise.bitwise_and(t, (b - 1))
    # Append the 32 freshly generated bits to the accumulator `f`.
    f = tf.bitwise.bitwise_or(tf.bitwise.left_shift(f, logb), x)
    if parallelism == 1:
      f.set_shape((1,))  # Correct for failed shape inference.
    bits += logb_scalar

    def add_val(bits, f, values, num):
      """Consumes 53 bits from `f` to emit one float64 value in [0, 1)."""
      new_val = tf.cast(
          tf.bitwise.bitwise_and(f, (2**val_53 - 1)),
          dtype=tf.float64) * (1 / 2**val_53)
      values = values.write(num, new_val)
      # Discard the 53 bits just consumed, keeping only the unused high bits
      # for the next value. (Must be an assignment, not `f += ...`: adding the
      # shifted value back onto `f` would retain the already-emitted low bits
      # and make subsequent outputs reuse them, breaking the consistency
      # between `f` and the `bits` counter.)
      f = tf.bitwise.right_shift(f, val_53)
      bits -= val_53
      num += 1
      return bits, f, values, num

    bits, f, values, num = tf.cond(bits >= val_53,
                                   lambda: add_val(bits, f, values, num),
                                   lambda: (bits, f, values, num))
    return f, bits, q, c, num, values

  def condition(f, bits, q, c, num, values):  # pylint: disable=unused-argument
    return num < num_iters

  _, _, _, _, _, values = tf.while_loop(
      condition,
      cmwc_step,
      [f, bits, q, c, num, values],
  )

  values = tf.reshape(values.stack(), [-1])
  # We generated parallelism * num_iters random values. Take a slice of the
  # first num_elements for the requested Tensor.
  values = values[:num_elements]
  values.set_shape((num_elements,))  # Correct for failed shape inference.
  return values