in tensorflow_gan/python/features/normalization.py [0:0]
# Imports and data-format constants this function relies on (defined at the
# top of normalization.py in tensorflow_gan).
import tensorflow as tf

DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'

def instance_norm(inputs,
center=True,
scale=True,
epsilon=1e-6,
activation_fn=None,
param_initializers=None,
reuse=None,
outputs_collections=None,
trainable=True,
data_format=DATA_FORMAT_NHWC,
scope=None):
"""Functional interface for the instance normalization layer.
Reference: https://arxiv.org/abs/1607.08022.
"Instance Normalization: The Missing Ingredient for Fast Stylization"
Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky
Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension is
      `batch_size`. Statistics are computed per example and per channel, over
      the remaining dimensions. The channels dimension is the last one if
      `data_format` is `NHWC` and the second one if `data_format` is `NCHW`.
center: If True, add offset of `beta` to normalized tensor. If False, `beta`
is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear or piecewise linear (e.g. `tf.nn.relu`), this
      can be disabled, since the scaling can be done by the next layer.
epsilon: Small float added to variance to avoid dividing by zero.
activation_fn: Activation function, default set to None to skip it and
maintain a linear activation.
    param_initializers: Optional dictionary of initializers for `beta` and
      `gamma`, keyed by `'beta'` and `'gamma'`. Unlike batch normalization,
      instance normalization keeps no moving statistics.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer, `scope` must be given.
    outputs_collections: Collections to add the outputs to.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
data_format: A string. `NHWC` (default) and `NCHW` are supported.
scope: Optional scope for `variable_scope`.
Returns:
A `Tensor` representing the output of the operation.
Raises:
ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If the channels dimension of `inputs` is undefined.
"""
inputs = tf.convert_to_tensor(value=inputs)
inputs_shape = inputs.shape
inputs_rank = inputs.shape.ndims
if inputs_rank is None:
raise ValueError('Inputs %s has undefined rank.' % inputs.name)
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
with tf.compat.v1.variable_scope(
scope, 'InstanceNorm', [inputs], reuse=reuse):
if data_format == DATA_FORMAT_NCHW:
reduction_axis = 1
# For NCHW format, rather than relying on implicit broadcasting, we
# explicitly reshape the params to params_shape_broadcast when computing
# the moments and the batch normalization.
      params_shape_broadcast = (
          [1, tf.compat.dimension_value(inputs_shape[1])] +
          [1] * (inputs_rank - 2))
else:
reduction_axis = inputs_rank - 1
params_shape_broadcast = None
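    # Normalize over every axis except batch (axis 0) and channels. The
    # channel position is deleted first on purpose: deleting axis 0 first
    # would shift the index of the channel axis in the `NCHW` case.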
moments_axes = list(range(inputs_rank))
del moments_axes[reduction_axis]
del moments_axes[0]
params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
if not params_shape.is_fully_defined():
raise ValueError('Inputs %s has undefined channels dimension %s.' % (
inputs.name, params_shape))
# Allocate parameters for the beta and gamma of the normalization.
beta, gamma = None, None
dtype = inputs.dtype.base_dtype
if param_initializers is None:
param_initializers = {}
if center:
beta_initializer = param_initializers.get(
'beta', tf.compat.v1.initializers.zeros())
beta = tf.compat.v1.get_variable(
name='beta',
shape=params_shape,
dtype=dtype,
initializer=beta_initializer,
trainable=trainable)
if params_shape_broadcast:
beta = tf.reshape(beta, params_shape_broadcast)
if scale:
gamma_initializer = param_initializers.get(
'gamma', tf.compat.v1.initializers.ones())
gamma = tf.compat.v1.get_variable(
name='gamma',
shape=params_shape,
dtype=dtype,
initializer=gamma_initializer,
trainable=trainable)
if params_shape_broadcast:
gamma = tf.reshape(gamma, params_shape_broadcast)
# Calculate the moments (instance activations).
mean, variance = tf.nn.moments(x=inputs, axes=moments_axes, keepdims=True)
# Compute instance normalization.
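    # `tf.nn.batch_normalization` accepts `offset` and `scale` of None, so
    # beta/gamma are simply omitted when `center`/`scale` are False. With the
    # per-instance moments above, this computes
    #   gamma * (x - mean) / sqrt(variance + epsilon) + beta.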
outputs = tf.nn.batch_normalization(
inputs, mean, variance, beta, gamma, epsilon, name='instancenorm')
if activation_fn is not None:
outputs = activation_fn(outputs)
if outputs_collections:
tf.compat.v1.add_to_collections(outputs_collections, outputs)
return outputs
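

# A minimal usage sketch (an addition, not part of the original module). It
# assumes TF1-style graph execution, because `tf.compat.v1.get_variable` and
# `variable_scope` are not supported under eager execution; the shape below is
# illustrative.
if __name__ == '__main__':
  tf.compat.v1.disable_eager_execution()
  # A hypothetical batch of 4 NHWC images, 32x32 pixels with 3 channels.
  images = tf.compat.v1.placeholder(tf.float32, [4, 32, 32, 3])
  net = instance_norm(images, activation_fn=tf.nn.relu,
                      scope='instance_norm_demo')
  print(net)  # -> A float32 Tensor of shape (4, 32, 32, 3).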