in tensorflow_gan/python/features/random_tensor_pool.py [0:0]
def tensor_pool(input_values,
                pool_size=50,
                pooling_probability=0.5,
                name='tensor_pool'):
  """Queue storing input values and returning random previously stored ones.

  Every time the returned `output_value` is evaluated, `input_value` is
  evaluated and stored in the pool. As long as the pool is not fully filled,
  the input value is also returned directly. Once the pool is full, a random
  sample is first popped from the pool (making room for the new input) and the
  input is stored. Then, with probability `pooling_probability`, the popped
  sample is returned. Otherwise (with probability `1-pooling_probability`) the
  current input is returned and the popped sample is discarded.

  All tensors in `input_values` are stored together as one pool element and
  share a single random draw. The returned structure is therefore either
  entirely the current inputs or entirely one previously pooled sample.

  Note during inference / testing, it may be appropriate to set
  `pool_size` = 0 or `pooling_probability` = 0.

  Args:
    input_values: An arbitrarily nested structure of `tf.Tensors`, from which to
      read values to be pooled.
    pool_size: An integer specifying the maximum size of the pool. Defaults to
      50. The value is converted with `int()`, so fractional values are
      truncated. A value of 0 disables pooling and returns `input_values`
      unchanged.
    pooling_probability: A float or float `Tensor` specifying the probability of
      getting a value from the pool, as opposed to just the current input. Only
      used once the pool is full. Defaults to 0.5.
    name: A string prefix for the name scope for all tensorflow ops; the ops
      are created under the scope `'<name>_pool_queue'`.

  Returns:
    A nested structure of `Tensor` objects with the same structure as
    `input_values`. With the given probability, the Tensor values are either the
    same as in `input_values` or a randomly chosen sample that was previously
    inserted in the pool.

  Raises:
    ValueError: If `pool_size` is negative.
  """
  pool_size = int(pool_size)
  if pool_size < 0:
    raise ValueError('`pool_size` is negative.')
  elif pool_size == 0:
    # Pooling disabled: hand back the original (unflattened) structure as-is.
    return input_values

  # Remember the original nesting so the flat outputs can be repacked at the
  # end; the queue itself works on a flat list of tensors.
  original_input_values = input_values
  input_values = tf.nest.flatten(input_values)

  # NOTE(review): the `values=` keyword belongs to the TF1 `name_scope`
  # signature (`tf.compat.v1.name_scope`); TF2's `tf.name_scope` accepts only
  # a name. Confirm which API `tf` refers to in this module.
  with tf.name_scope(
      '{}_pool_queue'.format(name),
      values=input_values + [pooling_probability]):
    # Each queue element is one flattened `input_values` list (one component
    # per leaf tensor). `capacity=pool_size` bounds the pool.
    # `min_after_dequeue=0` lets a dequeue proceed whenever at least one
    # element is present.
    # `shapes=None` means the queue records no static shapes, which is why the
    # shapes are restored with `set_shape` below.
    pool_queue = tf.queue.RandomShuffleQueue(
        capacity=pool_size,
        min_after_dequeue=0,
        dtypes=[v.dtype for v in input_values],
        shapes=None)

    # In pseudo code this code does the following:
    # if not pool_full:
    #   enqueue(input_values)
    #   return input_values
    # else:
    #   dequeue_values = dequeue_random_sample()
    #   enqueue(input_values)
    #   if rand() < pooling_probability:
    #     return dequeue_values
    #   else:
    #     return input_values

    def _get_input_value_pooled():
      # Pool not full yet: store the input and return it unchanged.
      enqueue_op = pool_queue.enqueue(input_values)
      with tf.control_dependencies([enqueue_op]):
        # `tf.identity` creates new ops inside the control-dependency scope,
        # so the returned tensors depend on `enqueue_op`. Returning
        # `input_values` directly would not add that dependency, because those
        # tensors were created outside this scope.
        return [tf.identity(v) for v in input_values]

    def _get_random_pool_value_and_enqueue_input():
      # Pool full: pop a random stored sample first, then store the input.
      # `_to_list` is defined elsewhere in this module. It presumably turns a
      # single-component dequeue result into a one-element list (confirm).
      dequeue_values = _to_list(pool_queue.dequeue())
      # Enqueue only after the dequeue. The queue's capacity equals
      # `pool_size`, so enqueueing into the full pool first would block.
      with tf.control_dependencies(dequeue_values):
        enqueue_op = pool_queue.enqueue(input_values)
        # The random draw is created under the enqueue dependency. Because the
        # cond's output depends on `prob`, evaluating the output always runs
        # the enqueue, even when the current input is what gets returned.
        with tf.control_dependencies([enqueue_op]):
          # One scalar draw decides for the whole flattened structure.
          prob = tf.random.uniform((), dtype=tf.float32) < pooling_probability
          return tf.cond(
              pred=prob,
              true_fn=lambda: dequeue_values,
              false_fn=lambda: input_values)

    # Branch on whether the pool still has room. `_to_list` presumably turns
    # a single-tensor cond result (one leaf) into a list (confirm against its
    # definition).
    output_values = _to_list(
        tf.cond(
            pred=pool_queue.size() < pool_size,
            true_fn=_get_input_value_pooled,
            false_fn=_get_random_pool_value_and_enqueue_input))

    # Make sure that the shape of `output_value` is set.
    # The queue was built with `shapes=None`, so pooled values carry no static
    # shape; copy it from the corresponding input leaf.
    for input_value, output_value in zip(input_values, output_values):
      output_value.set_shape(input_value.shape)

    # Restore the caller's original nesting.
    return tf.nest.pack_sequence_as(original_input_values, output_values)