in tensorflow_gan/python/tpu/normalization_ops.py [0:0]
def accumulated_moments_for_inference(mean, variance, is_training):
  """Use accumulated statistics for moments during inference.

  Creates (under the variable scope 'accu') non-trainable accumulator variables
  `accu_mean`, `accu_variance`, `accu_counter` and an int32 switch
  `update_accus`. In the training graph the current batch moments are returned
  unchanged. In the inference graph the accumulated moments
  `accu_mean / accu_counter` and `accu_variance / accu_counter` are returned.
  Additionally, whenever `update_accus` equals 1 at run time, the current batch
  moments are first added to the accumulators and the counter is incremented.
  The result is therefore the per-batch average of the accumulated moments, up
  to the 1e-12 counter offset.

  After training the user is responsible for filling the accumulators with the
  actual values (e.g. by setting `update_accus` to 1 and running inference on
  a number of batches).

  Args:
    mean: Tensor of shape [num_channels] with the mean of the current batch.
    variance: Tensor of shape [num_channels] with the variance of the current
      batch.
    is_training: Boolean, whether to construct ops for training or inference
      graph.

  Returns:
    Tuple of (mean, variance) to use. This can be the same as the inputs.
  """
  # Accumulators are non-trainable, so optimizers never touch them. They are
  # registered as model/global variables rather than trainable variables.
  variable_collections = [
      tf.compat.v1.GraphKeys.MODEL_VARIABLES,
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
  ]
  with tf.compat.v1.variable_scope('accu', values=[mean, variance]):
    # Create variables for accumulating batch statistic and use them during
    # inference. The ops for filling the accumulators must be created and run
    # before eval. See docstring above.
    accu_mean = tf.compat.v1.get_variable(
        'accu_mean',
        shape=mean.shape,
        initializer=tf.compat.v1.zeros_initializer(),
        trainable=False,
        collections=variable_collections)
    accu_variance = tf.compat.v1.get_variable(
        'accu_variance',
        shape=variance.shape,
        initializer=tf.compat.v1.zeros_initializer(),
        trainable=False,
        collections=variable_collections)
    # Starts at a tiny positive value instead of 0 so that the division below
    # is defined even before anything has been accumulated.
    accu_counter = tf.compat.v1.get_variable(
        'accu_counter',
        shape=[],
        initializer=tf.compat.v1.initializers.constant(1e-12),
        trainable=False,
        collections=variable_collections)
    # Run-time switch: when this equals 1, each inference step adds the current
    # batch moments to the accumulators.
    update_accus = tf.compat.v1.get_variable(
        'update_accus',
        shape=[],
        dtype=tf.int32,
        initializer=tf.compat.v1.zeros_initializer(),
        trainable=False,
        collections=variable_collections)
    # Named identities under this scope. Presumably this gives the batch-moment
    # tensors stable names for lookup when filling the accumulators (confirm
    # against callers).
    mean = tf.identity(mean, 'mean')
    variance = tf.identity(variance, 'variance')

    if is_training:
      return mean, variance

    logging.debug('Using accumulated moments.')
    # Return the accumulated batch statistics and add current batch statistics
    # to accumulators if update_accus variables equals 1.
    def update_accus_fn():
      return tf.group([
          tf.compat.v1.assign_add(accu_mean, mean),
          tf.compat.v1.assign_add(accu_variance, variance),
          tf.compat.v1.assign_add(accu_counter, 1),
      ])

    dep = tf.cond(
        pred=tf.equal(update_accus, 1),
        true_fn=update_accus_fn,
        false_fn=tf.no_op)
    with tf.control_dependencies([dep]):
      # Read the variables explicitly inside this context. A control-dependency
      # context only affects ops created within it. Using the variables
      # directly would, for legacy ref variables, reuse the snapshot tensor
      # created together with the variable (outside this context). The read
      # would then not be ordered after the accumulator update.
      accu_mean_value = accu_mean.read_value()
      accu_variance_value = accu_variance.read_value()
      accu_counter_value = accu_counter.read_value()
      return (accu_mean_value / accu_counter_value,
              accu_variance_value / accu_counter_value)