in tensorflow_gan/python/tpu/normalization_ops.py [0:0]
def standardize_batch(inputs,
                      is_training,
                      offset=None,
                      scale=None,
                      decay=0.999,
                      epsilon=1e-3,
                      data_format='NHWC',
                      use_moving_averages=True,
                      use_cross_replica_mean=None):
"""Adds TPU-enabled batch normalization layer.
Details on Batch Normalization can be found in 'Batch Normalization:
Accelerating Deep Network Training by Reducing Internal Covariate Shift',
Ioffe S. and Szegedy C. 2015 [http://arxiv.org/abs/1502.03167].
Note #1: This method computes the batch statistic across all TPU replicas,
thus simulating the true batch norm in the distributed setting. If one wants
to avoid the cross-replica communication set use_cross_replica_mean=False.
Note #2: When is_training is True the moving_mean and moving_variance need
to be updated in each training step. By default, the update_ops are placed
in `tf.GraphKeys.UPDATE_OPS` and they need to be added as a dependency to
the `train_op`. For example:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
updates = tf.group(*update_ops)
total_loss = control_flow_ops.with_dependencies([updates], total_loss)
Note #3: Reasonable values for `decay` are close to 1.0, typically in the
multiple-nines range: 0.999, 0.99, 0.9, etc. Lower the `decay` value (trying
`decay`=0.9) if model experiences reasonably good training performance but
poor validation and/or test performance.
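
  The moving statistics follow the standard exponential moving average
  update; as a sketch of what `moving_moments_for_inference` applies each
  training step:

    moving_mean = decay * moving_mean + (1 - decay) * batch_mean
    moving_variance = decay * moving_variance + (1 - decay) * batch_variance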

  Args:
    inputs: A tensor with 2 or 4 dimensions, where the first dimension is
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC`, and over all but the second dimension if
      `data_format` is `NCHW`.
    is_training: Whether or not the layer is in training mode. In training
      mode it accumulates the statistics of the moments into `moving_mean`
      and `moving_variance` using an exponential moving average with the
      given `decay`. When `is_training=False`, these variables are not
      updated, and the precomputed values are used verbatim.
    offset: An offset `Tensor`, often denoted `beta` in equations, or `None`.
      If present, it will be added to the normalized tensor.
    scale: A scale `Tensor`, often denoted `gamma` in equations, or `None`.
      If present, the scale is applied to the normalized tensor.
    decay: Decay for the moving averages. See the notes above for reasonable
      values.
    epsilon: Small float added to the variance to avoid dividing by zero.
    data_format: Input data format. `NHWC` or `NCHW`.
    use_moving_averages: If True, keep moving averages of the mean and
      variance that are used during inference. Otherwise use accumulators.
    use_cross_replica_mean: If True, add operations to compute batch norm
      statistics across all TPU cores. These ops are not compatible with
      other platforms. The default (None) will only add the operations if
      running on TPU.

  Returns:
    The normalized tensor with the same type and shape as `inputs`.
  """
  if data_format not in {'NCHW', 'NHWC'}:
    raise ValueError(
        'Invalid data_format {}. Allowed: NCHW, NHWC.'.format(data_format))
  if use_cross_replica_mean is None:
    # Default to global batch norm only on TPUs.
    use_cross_replica_mean = (
        tpu_function.get_tpu_context().number_of_shards is not None)
    logging.debug('Automatically determined use_cross_replica_mean=%s.',
                  use_cross_replica_mean)
  inputs = tf.convert_to_tensor(value=inputs)
  inputs_dtype = inputs.dtype
  inputs_shape = inputs.get_shape()
  num_channels = tf.compat.dimension_value(inputs.shape[-1])
  if num_channels is None:
    raise ValueError('`C` dimension must be known but is None')
  inputs_rank = inputs_shape.ndims
  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank' % inputs.name)
  elif inputs_rank not in [2, 4]:
    raise ValueError(
        'Inputs %s has unsupported rank.'
        ' Expected 2 or 4 but got %d' % (inputs.name, inputs_rank))
  # Bring 2-D inputs into 4-D format.
  if inputs_rank == 2:
    new_shape = [-1, 1, 1, num_channels]
    if data_format == 'NCHW':
      new_shape = [-1, num_channels, 1, 1]
    inputs = tf.reshape(inputs, new_shape)
    if offset is not None:
      offset = tf.reshape(offset, new_shape)
    if scale is not None:
      scale = tf.reshape(scale, new_shape)
  # Execute a distributed batch normalization.
  axis = 1 if data_format == 'NCHW' else 3
  inputs = tf.cast(inputs, tf.float32)
  reduction_axes = [i for i in range(4) if i != axis]
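  # Compute the batch moments over all non-channel axes, either globally
  # across all TPU replicas or locally via sufficient statistics.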
  if use_cross_replica_mean:
    mean, variance = cross_replica_moments(inputs, reduction_axes)
  else:
    counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
        inputs, reduction_axes, keepdims=False)
    mean, variance = tf.nn.normalize_moments(
        counts, mean_ss, variance_ss, shift=None)
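  # Swap in statistics suitable for inference: exponential moving averages
  # by default, otherwise accumulated moments. During training these helpers
  # update their state and pass the batch moments through.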
  if use_moving_averages:
    mean, variance = moving_moments_for_inference(
        mean=mean, variance=variance, is_training=is_training, decay=decay)
  else:
    mean, variance = accumulated_moments_for_inference(
        mean=mean, variance=variance, is_training=is_training)
  outputs = tf.nn.batch_normalization(
      inputs,
      mean=mean,
      variance=variance,
      offset=offset,
      scale=scale,
      variance_epsilon=epsilon)
  outputs = tf.cast(outputs, inputs_dtype)
  # Bring outputs for 2-D inputs back into 2-D format.
  if inputs_rank == 2:
    outputs = tf.reshape(outputs, [-1] + inputs_shape[1:].as_list())
  outputs.set_shape(inputs_shape)
  return outputs
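

# Illustrative only (not part of the library API): a minimal TF1-style
# training step showing how the `UPDATE_OPS` dependency from Note #2 is
# wired up. `activations`, `beta`, `gamma`, `total_loss`, and `optimizer`
# are placeholders assumed to be created by the caller.
def _example_train_step(activations, beta, gamma, total_loss, optimizer):
  normalized = standardize_batch(
      activations, is_training=True, offset=beta, scale=gamma)
  # The moving-average update ops must run together with the train op.
  update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss)
  return normalized, train_op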