def accumulated_moments_for_inference()

in tensorflow_gan/python/tpu/normalization_ops.py [0:0]


def accumulated_moments_for_inference(mean, variance, is_training):
  """Pick between batch moments and accumulated moments.

  Four non-trainable variables are created under the 'accu' variable scope:
  running sums for the mean and the variance, a scalar counter and an int32
  scalar flag named 'update_accus'. In the training graph the batch moments
  are passed straight through. In the inference graph the returned moments
  are the running sums divided by the counter; before that division the
  current batch moments are added to the sums (and 1 to the counter) whenever
  the 'update_accus' variable equals 1.

  Filling the accumulators with meaningful values after training is the
  caller's responsibility.

  Args:
    mean: Tensor of shape [num_channels] with the mean of the current batch.
    variance: Tensor of shape [num_channels] with the variance of the current
      batch.
    is_training: Boolean, whether to construct ops for the training graph or
      for the inference graph.

  Returns:
    Tuple of (mean, variance) to use. In the training graph these are
    identity ops on the inputs.
  """
  state_collections = [
      tf.compat.v1.GraphKeys.MODEL_VARIABLES,
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
  ]

  def _state_variable(name, shape, initializer, **extra):
    # All accumulator state is non-trainable and registered in the same
    # collections; only name, shape, initializer (and dtype) differ.
    return tf.compat.v1.get_variable(
        name,
        shape=shape,
        initializer=initializer,
        trainable=False,
        collections=state_collections,
        **extra)

  with tf.compat.v1.variable_scope('accu', values=[mean, variance]):
    # The accumulator variables are created for both graph modes; they are
    # only read when building the inference graph. The ops that fill them
    # must be run before eval (see docstring).
    mean_sum = _state_variable(
        'accu_mean', mean.shape, tf.compat.v1.zeros_initializer())
    variance_sum = _state_variable(
        'accu_variance', variance.shape, tf.compat.v1.zeros_initializer())
    # Starts at 1e-12 rather than 0 -- presumably so that dividing by the
    # counter before anything was accumulated is not a division by zero.
    num_batches = _state_variable(
        'accu_counter', [], tf.compat.v1.initializers.constant(1e-12))
    update_flag = _state_variable(
        'update_accus', [], tf.compat.v1.zeros_initializer(), dtype=tf.int32)

    batch_mean = tf.identity(mean, 'mean')
    batch_variance = tf.identity(variance, 'variance')

    if is_training:
      return batch_mean, batch_variance

    logging.debug('Using accumulated moments.')

    def _accumulate_batch():
      # Add the current batch statistics to the running sums and count the
      # batch.
      increments = [
          tf.compat.v1.assign_add(mean_sum, batch_mean),
          tf.compat.v1.assign_add(variance_sum, batch_variance),
          tf.compat.v1.assign_add(num_batches, 1),
      ]
      return tf.group(increments)

    # Accumulate only while the 'update_accus' variable is set to 1.
    should_accumulate = tf.equal(update_flag, 1)
    maybe_accumulate = tf.cond(
        pred=should_accumulate,
        true_fn=_accumulate_batch,
        false_fn=tf.no_op)
    # The divisions are created under a control dependency on the
    # (conditional) accumulation op.
    with tf.control_dependencies([maybe_accumulate]):
      averaged_mean = mean_sum / num_batches
      averaged_variance = variance_sum / num_batches
      return averaged_mean, averaged_variance