def tensor_pool()

in tensorflow_gan/python/features/random_tensor_pool.py [0:0]


def tensor_pool(input_values,
                pool_size=50,
                pooling_probability=0.5,
                name='tensor_pool'):
  """Queue storing input values and returning random previously stored ones.

  Every time the returned output is evaluated, `input_values` is evaluated and
  stored in the pool. As long as the pool is not fully filled, the input is
  also returned directly. Once the pool is full, a random sample is popped
  from the pool before the input is stored; the popped sample is returned with
  probability `pooling_probability`, otherwise the current input is returned
  (the popped sample is then dropped). Note during inference / testing, it may
  be appropriate to set `pool_size` = 0 or `pooling_probability` = 0.

  Args:
    input_values: An arbitrarily nested structure of `tf.Tensors`, from which to
      read values to be pooled.
    pool_size: An integer specifying the maximum size of the pool. Defaults to
      50. It is converted with `int()`; a value of 0 disables pooling and
      `input_values` is returned unchanged.
    pooling_probability: A float `Tensor` specifying the probability of getting
      a value from the pool, as opposed to just the current input.
    name: A string prefix for the name scope for all tensorflow ops.

  Returns:
    A nested structure of `Tensor` objects with the same structure as
    `input_values`. With the given probability, the Tensor values are either the
    same as in `input_values` or a randomly chosen sample that was previously
    inserted in the pool.

  Raises:
    ValueError: If `pool_size` is negative.
  """
  pool_size = int(pool_size)
  if pool_size < 0:
    raise ValueError('`pool_size` is negative.')
  elif pool_size == 0:
    # Pooling disabled: no ops are created and the input passes straight
    # through.
    return input_values

  # Keep the nested structure so the flat outputs can be re-packed at the end.
  original_input_values = input_values
  input_values = tf.nest.flatten(input_values)

  # Only the v1 `name_scope` signature accepts `values`; the TF2
  # `tf.name_scope` takes a name only, so the compat endpoint is used to keep
  # this call valid under both API versions.
  with tf.compat.v1.name_scope(
      '{}_pool_queue'.format(name),
      values=input_values + [pooling_probability]):
    pool_queue = tf.queue.RandomShuffleQueue(
        capacity=pool_size,
        min_after_dequeue=0,
        dtypes=[v.dtype for v in input_values],
        shapes=None)

    # In pseudo code this code does the following:
    # if not pool_full:
    #   enqueue(input_values)
    #   return input_values
    # else
    #   dequeue_values = dequeue_random_sample()
    #   enqueue(input_values)
    #   if rand() < pooling_probability:
    #     return dequeue_values
    #   else
    #     return input_values

    def _get_input_value_pooled():
      """Enqueues the input and returns it (pool not yet full)."""
      enqueue_op = pool_queue.enqueue(input_values)
      # The identities are created under the control dependency so that
      # evaluating the outputs forces the enqueue to run.
      with tf.control_dependencies([enqueue_op]):
        return [tf.identity(v) for v in input_values]

    def _get_random_pool_value_and_enqueue_input():
      """Pops a pooled sample, enqueues the input, returns one of the two."""
      dequeue_values = _to_list(pool_queue.dequeue())
      # Dequeue first so there is room in the full queue for the new input.
      with tf.control_dependencies(dequeue_values):
        enqueue_op = pool_queue.enqueue(input_values)
        with tf.control_dependencies([enqueue_op]):
          prob = tf.random.uniform((), dtype=tf.float32) < pooling_probability
          return tf.cond(
              pred=prob,
              true_fn=lambda: dequeue_values,
              false_fn=lambda: input_values)

    output_values = _to_list(
        tf.cond(
            pred=pool_queue.size() < pool_size,
            true_fn=_get_input_value_pooled,
            false_fn=_get_random_pool_value_and_enqueue_input))

    # Make sure that the shape of `output_value` is set.
    for input_value, output_value in zip(input_values, output_values):
      output_value.set_shape(input_value.shape)

  return tf.nest.pack_sequence_as(original_input_values, output_values)