in tensorflow_gan/python/features/virtual_batchnorm.py [0:0]
def __init__(self,
             reference_batch,
             axis=-1,
             epsilon=1e-3,
             center=True,
             scale=True,
             beta_initializer=tf.compat.v1.initializers.zeros(),
             gamma_initializer=tf.compat.v1.initializers.ones(),
             beta_regularizer=None,
             gamma_regularizer=None,
             trainable=True,
             name=None,
             batch_axis=0):
"""Initialize virtual batch normalization object.
We precompute the 'mean' and 'mean squared' of the reference batch, so that
`__call__` is efficient. This means that the axis must be supplied when the
object is created, not when it is called.
We precompute 'square mean' instead of 'variance', because the square mean
can be easily adjusted on a per-example basis.
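
  As a sketch (names illustrative, weights as computed in the constructor
  below): with a reference batch of size `N`, a new example's statistics are
  folded in with weight `1 / (N + 1)`:

    mean     = example_mean / (N + 1) + N * ref_mean / (N + 1)
    mean_sq  = example_mean_sq / (N + 1) + N * ref_mean_sq / (N + 1)
    variance = mean_sq - mean**2

  The variance itself does not combine as a simple weighted average, which
  is why the mean of squares is stored instead.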

  Args:
    reference_batch: A minibatch tensor. This will form the reference data
      from which the normalization statistics are calculated. See
      https://arxiv.org/abs/1606.03498 for more details.
    axis: Integer, the axis that should be normalized (typically the
      features axis). For instance, after a `Convolution2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used.
      When the next layer is linear (e.g. also `nn.relu`), this can be
      disabled, since the scaling can be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    trainable: Boolean. If `True`, also add variables to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the ops.
    batch_axis: The axis of the batch dimension. This dimension is treated
      differently in virtual batch normalization than in standard batch
      normalization.

  Raises:
    ValueError: If `reference_batch` has unknown dimensions at graph
      construction.
    ValueError: If `batch_axis` is the same as `axis`.
  """
  axis = _validate_init_input_and_get_axis(reference_batch, axis)
  self._epsilon = epsilon
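  # Scalar defaults for `beta` and `gamma`; replaced by variables at the end
  # of this method when `center` / `scale` is True.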
  self._beta = 0
  self._gamma = 1
  self._batch_axis = _validate_init_input_and_get_axis(
      reference_batch, batch_axis)
  if axis == self._batch_axis:
    raise ValueError('`axis` and `batch_axis` cannot be the same.')
  with tf.compat.v1.variable_scope(
      name, 'VBN', values=[reference_batch]) as self._vs:
    self._reference_batch = reference_batch

    # Calculate important shapes:
    #   1) Reduction axes for the reference batch
    #   2) Broadcast shape, if necessary
    #   3) Reduction axes for the virtual batchnormed batch
    #   4) Shape for optional parameters
    input_shape = self._reference_batch.shape
    ndims = input_shape.ndims
    reduction_axes = list(range(ndims))
    del reduction_axes[axis]
    self._broadcast_shape = [1] * len(input_shape)
    self._broadcast_shape[axis] = input_shape.dims[axis]
    self._example_reduction_axes = list(range(ndims))
    del self._example_reduction_axes[max(axis, self._batch_axis)]
    del self._example_reduction_axes[min(axis, self._batch_axis)]
    params_shape = self._reference_batch.shape[axis]

    # Determine whether broadcasting is needed. This is slightly different
    # from the `nn.batch_normalization` case, due to `batch_dim`.
    self._needs_broadcasting = (
        sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
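    # For example (hypothetical shapes): for an NHWC input (`ndims == 4`)
    # with `axis=3` and `batch_axis=0`, the example reduction axes are
    # [1, 2] != [0, 1], so broadcasting is needed; it is skipped only when
    # `axis` and `batch_axis` are the last two dimensions.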
    # Calculate the sufficient statistics for the reference batch in a way
    # that can be easily modified by additional examples.
    self._ref_mean, self._ref_mean_squares = vbn_statistics(
        self._reference_batch, reduction_axes)
    self._ref_variance = self._ref_mean_squares - tf.square(self._ref_mean)

    # Virtual batch normalization uses a weighted average between example
    # statistics and the reference batch statistics.
    ref_batch_size = _static_or_dynamic_batch_size(
        self._reference_batch, self._batch_axis)
    self._example_weight = 1. / (tf.cast(ref_batch_size, tf.float32) + 1.)
    self._ref_weight = 1. - self._example_weight
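    # E.g., with a reference batch of 64 examples, a new example gets weight
    # 1/65 and the reference statistics get weight 64/65.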
    # Make the variables, if necessary.
    if center:
      self._beta = tf.compat.v1.get_variable(
          name='beta',
          shape=(params_shape,),
          initializer=beta_initializer,
          regularizer=beta_regularizer,
          trainable=trainable)
    if scale:
      self._gamma = tf.compat.v1.get_variable(
          name='gamma',
          shape=(params_shape,),
          initializer=gamma_initializer,
          regularizer=gamma_regularizer,
          trainable=trainable)
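
# A minimal usage sketch (illustrative, not part of the module): it assumes
# the enclosing `VBN` class and its `__call__` method (defined elsewhere in
# this file), which normalizes new examples against the reference statistics
# precomputed above.
#
#   reference_batch = tf.random.normal([64, 32, 32, 3])
#   vbn = VBN(reference_batch, axis=-1)
#   new_examples = tf.random.normal([16, 32, 32, 3])
#   normalized = vbn(new_examples)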