in models/official/amoeba_net/network_utils.py [0:0]
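# NOTE: This excerpt assumes the module-level imports of network_utils.py for
# the names used below (the exact import forms are an assumption), e.g.:
#   import tensorflow as tf                                  # TF 1.x API
#   from tensorflow.contrib import framework as contrib_framework
#   from tensorflow.python.training import moving_averages
# as well as the `cross_replica_average` helper defined elsewhere in this file.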
def batch_norm(inputs,
decay=0.999,
center=True,
scale=False,
epsilon=0.001,
moving_vars='moving_vars',
activation_fn=None,
is_training=True,
data_format='NHWC',
reuse=None,
num_shards=None,
distributed_group_size=1,
scope=None):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift"
Sergey Ioffe, Christian Szegedy
Can be used as a normalizer function for conv2d and fully_connected.
Note: when `is_training` is True the `moving_mean` and `moving_variance` need
to be updated. The update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
need to be added as a dependency to the `train_op`. For example:

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  if update_ops:
    updates = tf.group(*update_ops)
    total_loss = control_flow_ops.with_dependencies([updates], total_loss)

This function does not expose an `updates_collections` argument; when
`is_training` is True the update ops are always added to
`tf.GraphKeys.UPDATE_OPS`.

Args:
inputs: A tensor with 2 or more dimensions, where the first dimension has
`batch_size`. The normalization is over all but the last dimension if
`data_format` is `NHWC` and the second dimension if `data_format` is
`NCHW`.
decay: Decay for the moving average. Reasonable values for `decay` are close
to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower
the `decay` value (try `decay=0.9`) if the model has reasonably good training
performance but poor validation and/or test performance.
center: If True, add offset of `beta` to normalized tensor. If False,
`beta` is ignored.
scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
next layer is linear (or, e.g., `nn.relu`), this can be disabled since the
scaling can be done by the next layer.
epsilon: Small float added to variance to avoid dividing by zero.
moving_vars: Name of collection created for moving variables.
activation_fn: Activation function, default set to None to skip it and
maintain a linear activation.
is_training: Whether or not the layer is in training mode. In training mode
it would accumulate the statistics of the moments into `moving_mean` and
`moving_variance` using an exponential moving average with the given
`decay`. When it is not in training mode then it would use the values of
the `moving_mean` and the `moving_variance`.
data_format: Input data format, either 'NHWC' or 'NCHW'.
reuse: Whether or not the layer and its variables should be reused. To be
able to reuse the layer, `scope` must be given.
num_shards: Number of shards that participate in the global reduction.
Defaults to None, which skips the cross-replica sum and normalizes across
local examples only.
distributed_group_size: Number of replicas to normalize across in the
distributed batch normalization.
scope: Optional scope for `variable_scope`.

Returns:
A `Tensor` representing the output of the operation.

Raises:
ValueError: If the rank of `inputs` is undefined.
ValueError: If the rank of `inputs` is neither 2 nor 4.
ValueError: If the `C` dimension of `inputs` is undefined.
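
Example usage (a minimal sketch; `net`, `is_training` and the argument values
below are illustrative placeholders, not recommended settings):

  net = batch_norm(net,
                   decay=0.999,
                   scale=True,
                   is_training=is_training,
                   data_format='NHWC',
                   num_shards=8,
                   distributed_group_size=1,
                   scope='bn')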
"""
trainable = True
with tf.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse):
inputs = tf.convert_to_tensor(inputs)
original_shape = inputs.get_shape()
original_rank = original_shape.ndims
if original_rank is None:
raise ValueError('Inputs %s has undefined rank' % inputs.name)
elif original_rank not in [2, 4]:
raise ValueError('Inputs %s has unsupported rank.'
' Expected 2 or 4 but got %d' % (inputs.name,
original_rank))
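# 2-D inputs are reshaped to 4-D (with singleton spatial dimensions) so the
# 4-D batch norm ops below can be applied; the original 2-D shape is restored
# before returning.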
if original_rank == 2:
channels = inputs.get_shape()[-1].value
if channels is None:
raise ValueError('`C` dimension must be known but is None')
new_shape = [-1, 1, 1, channels]
if data_format == 'NCHW':
new_shape = [-1, channels, 1, 1]
inputs = tf.reshape(inputs, new_shape)
inputs_shape = inputs.get_shape()
if data_format == 'NHWC':
params_shape = inputs_shape[-1:]
else:
params_shape = inputs_shape[1:2]
if not params_shape.is_fully_defined():
raise ValueError('Inputs %s has undefined `C` dimension %s.' %
(inputs.name, params_shape))
# Allocate parameters for the beta and gamma of the normalization.
trainable_beta = trainable and center
collections = [tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES]
beta = contrib_framework.variable(
'beta',
params_shape,
collections=collections,
initializer=tf.zeros_initializer(),
trainable=trainable_beta)
trainable_gamma = trainable and scale
gamma = contrib_framework.variable(
'gamma',
params_shape,
collections=collections,
initializer=tf.ones_initializer(),
trainable=trainable_gamma)
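# Note: `beta` and `gamma` are created even when `center`/`scale` are False;
# only their `trainable` flag is toggled, because the batch norm ops below
# are always passed both an offset (`beta`) and a scale (`gamma`) tensor.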
# Create moving_mean and moving_variance variables and add them to the
# appropriate collections.
moving_collections = [moving_vars,
tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
tf.GraphKeys.MODEL_VARIABLES,
tf.GraphKeys.GLOBAL_VARIABLES]
# Disable the partition setting for moving_mean and moving_variance,
# as the assign_moving_average op below doesn't support partitioned variables.
scope = tf.get_variable_scope()
partitioner = scope.partitioner
scope.set_partitioner(None)
moving_mean = contrib_framework.variable(
'moving_mean',
params_shape,
initializer=tf.zeros_initializer(),
trainable=False,
collections=moving_collections)
moving_variance = contrib_framework.variable(
'moving_variance',
params_shape,
initializer=tf.ones_initializer(),
trainable=False,
collections=moving_collections)
# Restore scope's partitioner setting.
scope.set_partitioner(partitioner)
# Optionally use a cross-replica sum so that mean and variance are computed
# over a group of replicas rather than a single one.
# First compute mean and variance.
if is_training:
if distributed_group_size > 1:
# Execute a distributed batch normalization
if data_format == 'NCHW':
axis = 1
else:
axis = 3
input_shape = inputs.get_shape()
inputs_dtype = inputs.dtype
inputs = tf.cast(inputs, tf.float32)
ndims = len(input_shape)
reduction_axes = [i for i in range(ndims) if i != axis]
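# Compute per-replica sufficient statistics, average them across the replica
# group, and only then derive the moments, so the resulting mean/variance
# reflect the combined batch of `distributed_group_size` replicas.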
counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
inputs, reduction_axes, keep_dims=False)
mean_ss = cross_replica_average(mean_ss, num_shards,
distributed_group_size)
variance_ss = cross_replica_average(variance_ss, num_shards,
distributed_group_size)
mean, variance = tf.nn.normalize_moments(
counts, mean_ss, variance_ss, shift=None)
outputs = tf.nn.batch_normalization(inputs, mean, variance, beta, gamma,
epsilon)
outputs = tf.cast(outputs, inputs_dtype)
else:
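# Non-distributed path: fused_batch_norm computes the moments over the
# local (per-replica) batch only.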
outputs, mean, variance = tf.nn.fused_batch_norm(
inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
else:
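# Inference path: normalize with the accumulated moving statistics rather
# than batch statistics.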
outputs, mean, variance = tf.nn.fused_batch_norm(
inputs,
gamma,
beta,
mean=moving_mean,
variance=moving_variance,
epsilon=epsilon,
is_training=False,
data_format=data_format)
if is_training:
update_moving_mean = moving_averages.assign_moving_average(
moving_mean,
tf.cast(mean, moving_mean.dtype),
decay,
zero_debias=False)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance,
tf.cast(variance, moving_variance.dtype),
decay,
zero_debias=False)
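# The update ops land in the 'update_ops' collection (the same key as
# tf.GraphKeys.UPDATE_OPS); the caller is expected to run them alongside the
# train op, e.g. via the dependency pattern shown in the docstring.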
tf.add_to_collection('update_ops', update_moving_mean)
tf.add_to_collection('update_ops', update_moving_variance)
outputs.set_shape(inputs_shape)
if original_shape.ndims == 2:
outputs = tf.reshape(outputs, original_shape)
if activation_fn is not None:
outputs = activation_fn(outputs)
return outputs