def group_norm()

in tensorflow_gan/python/features/normalization.py [0:0]


def group_norm(inputs,
               groups=32,
               channels_axis=-1,
               reduction_axes=(-3, -2),
               center=True,
               scale=True,
               epsilon=1e-6,
               activation_fn=None,
               param_initializers=None,
               reuse=None,
               outputs_collections=None,
               trainable=True,
               scope=None,
               mean_close_to_zero=False):
  """Functional interface for the group normalization layer.

  Reference: https://arxiv.org/abs/1803.08494.

    "Group Normalization", Yuxin Wu, Kaiming He

  The channels axis is split into `groups` groups; mean and variance are
  computed per group over the within-group channels and `reduction_axes`, and
  the input is normalized with them before the optional per-channel `gamma`
  scale and `beta` offset are applied.

  Args:
    inputs: A Tensor (or value convertible to one) of known rank, one axis of
      which is channels. The channels dimension and every dimension listed in
      `reduction_axes` must be statically defined; other dimensions may be
      dynamic.
    groups: Positive integer. Divide the channels into this number of groups
      over which normalization statistics are computed. This number must be
      commensurate with (i.e. evenly divide) the number of channels in
      `inputs`.
    channels_axis: An integer. Specifies index of channels axis which will be
      broken into `groups`, each of which whose statistics will be computed
      across. Must be mutually exclusive with `reduction_axes`. Preferred usage
      is to specify negative integers to be agnostic as to whether a batch
      dimension is included.
    reduction_axes: Tuple of integers. Specifies dimensions over which
       statistics will be accumulated. Must be mutually exclusive with
       `channels_axis`. Statistics will not be accumulated across axes not
       specified in `reduction_axes` nor `channels_axis`. Preferred usage is to
       specify negative integers to be agnostic to whether a batch dimension is
       included.

      Some sample usage cases:
        NHWC format: channels_axis=-1, reduction_axes=[-3, -2]
        NCHW format: channels_axis=-3, reduction_axes=[-2, -1]

    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `tf.nn.relu`), this can
      be disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional dict of initializers. Only the keys `'beta'`
      and `'gamma'` are consulted; they default to zeros and ones respectively.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional scope for `variable_scope`.
    mean_close_to_zero: The mean of `input` before ReLU will be close to zero
      when batch size >= 4k for Resnet-50 on TPU. If `True`, use
      `tf.nn.sufficient_statistics` and `tf.nn.normalize_moments` to calculate
      the variance. This is the same behavior as `fused` equals `True` in batch
      normalization. If `False`, use `tf.nn.moments` to calculate the variance.
      When `mean` is close to zero, like 1e-4, use `mean` to calculate the
      variance may have poor result due to repeated roundoff error and
      denormalization in `mean`.  When `mean` is large, like 1e2,
      sum(`input`^2) is so large that only the high-order digits of the elements
      are being accumulated. Thus, use sum(`input` - `mean`)^2/n to calculate
      the variance has better accuracy compared to (sum(`input`^2)/n - `mean`^2)
      when `mean` is large.


  Returns:
    A `Tensor` representing the output of the operation, reshaped back to the
    shape of `inputs`.

  Raises:
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If the channels dimension or any reduction dimension of
      `inputs` is undefined.
    ValueError: If `groups` is not positive, exceeds the number of channels, or
      is not commensurate with the number of channels.
    ValueError: If reduction_axes or channels_axis are out of bounds.
    ValueError: If reduction_axes are not mutually exclusive with channels_axis.
  """
  # TODO(shlens): Support partially defined shapes for the inputs.
  inputs = tf.convert_to_tensor(value=inputs)

  ndims = inputs.shape.ndims
  if ndims is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  # Check both bounds: an axis below -ndims would stay negative after
  # standardization and silently select the wrong dimension via Python's
  # negative indexing.
  if not -ndims <= channels_axis < ndims:
    raise ValueError('Axis is out of bounds.')

  # Use dynamic shape for not fully defined dimensions in the inputs.
  dynamic_shape = tf.shape(input=inputs)
  input_shape_list = []
  for i, dim in enumerate(inputs.shape):
    if tf.compat.dimension_value(dim) is None:
      input_shape_list.append(dynamic_shape[i])
    else:
      input_shape_list.append(dim)

  # Standardize the channels_axis to be positive and identify # of channels.
  if channels_axis < 0:
    channels_axis = ndims + channels_axis
  channels = tf.compat.dimension_value(inputs.shape[channels_axis])

  if channels is None:
    raise ValueError('Inputs %s has undefined channel dimension: %d.' % (
        inputs.name, channels_axis))

  # Standardize the reduction_axes to be positive.
  reduction_axes = [a + ndims if a < 0 else a for a in reduction_axes]

  for a in reduction_axes:
    # After standardization every valid axis lies in [0, ndims).
    if a < 0 or a >= ndims:
      raise ValueError('Axis is out of bounds.')
    if tf.compat.dimension_value(inputs.shape[a]) is None:
      raise ValueError('Inputs %s has undefined dimensions %d.' % (
          inputs.name, a))
    if channels_axis == a:
      raise ValueError('reduction_axis must be mutually exclusive '
                       'with channels_axis')
  # groups < 1 is rejected up front so that groups == 0 cannot reach the modulo
  # below and negative values cannot produce a nonsensical reshape.
  if groups < 1 or groups > channels:
    raise ValueError('Invalid groups %d for %d channels.' % (groups, channels))
  if channels % groups != 0:
    raise ValueError('%d channels is not commensurate with %d groups.' %
                     (channels, groups))

  # Determine axes before channels. Some examples of common image formats:
  #  'NCHW': before = [N], after = [HW]
  #  'NHWC': before = [NHW], after = []
  axes_before_channels = input_shape_list[:channels_axis]
  axes_after_channels = input_shape_list[channels_axis+1:]

  # Manually broadcast the parameters to conform to the number of groups.
  params_shape_broadcast = ([1] * len(axes_before_channels) +
                            [groups, channels // groups] +
                            [1] * len(axes_after_channels))

  # Reshape the input by the group within the channel dimension.
  inputs_shape = (axes_before_channels + [groups, channels // groups] +
                  axes_after_channels)
  inputs = tf.reshape(inputs, inputs_shape)

  # Determine the dimensions across which moments are calculated. The reshape
  # above inserted one extra axis at `channels_axis`, so the within-group
  # channel axis sits at `channels_axis + 1` and every original axis located
  # after the channels axis is shifted right by one.
  moments_axes = [channels_axis + 1]
  for a in reduction_axes:
    if a > channels_axis:
      moments_axes.append(a + 1)
    else:
      moments_axes.append(a)

  with tf.compat.v1.variable_scope(scope, 'GroupNorm', [inputs], reuse=reuse):
    # Note that the params_shape is the number of channels always.
    params_shape = [channels]

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}
    if center:
      beta_initializer = param_initializers.get(
          'beta', tf.compat.v1.initializers.zeros())
      beta = tf.compat.v1.get_variable(
          name='beta',
          shape=params_shape,
          dtype=dtype,
          initializer=beta_initializer,
          trainable=trainable)
      beta = tf.reshape(beta, params_shape_broadcast)

    if scale:
      gamma_initializer = param_initializers.get(
          'gamma', tf.compat.v1.initializers.ones())
      gamma = tf.compat.v1.get_variable(
          name='gamma',
          shape=params_shape,
          dtype=dtype,
          initializer=gamma_initializer,
          trainable=trainable)
      gamma = tf.reshape(gamma, params_shape_broadcast)

    # Calculate the moments.
    if mean_close_to_zero:
      # One pass algorithm returns better result when mean is close to zero.
      counts, means_ss, variance_ss, _ = tf.nn.sufficient_statistics(
          inputs, moments_axes, keepdims=True)
      mean, variance = tf.nn.normalize_moments(
          counts, means_ss, variance_ss, shift=None)
    else:
      mean, variance = tf.nn.moments(
          x=inputs, axes=moments_axes, keepdims=True)

    # Compute normalization.
    # TODO(shlens): Fix tf.nn.batch_normalization to handle the 5-D Tensor
    # appropriately so that this operation may be faster.
    # The normalization (x - mean) * rsqrt(var + eps) * gamma + beta is folded
    # into a single multiply-add: outputs = inputs * gain + offset.
    gain = tf.math.rsqrt(variance + epsilon)
    offset = -mean * gain
    if gamma is not None:
      gain *= gamma
      offset *= gamma
    if beta is not None:
      offset += beta
    outputs = inputs * gain + offset

    # Collapse the groups into the channel dimension.
    outputs = tf.reshape(outputs, input_shape_list)

    if activation_fn is not None:
      outputs = activation_fn(outputs)

    if outputs_collections:
      tf.compat.v1.add_to_collections(outputs_collections, outputs)

    return outputs