def wasserstein_gradient_penalty()

in tensorflow_gan/python/losses/losses_impl.py [0:0]


def wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs,
    discriminator_fn,
    discriminator_scope,
    epsilon=1e-10,
    target=1.0,
    one_sided=False,
    weights=1.0,
    scope=None,
    loss_collection=tf.compat.v1.GraphKeys.LOSSES,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """The gradient penalty for the Wasserstein discriminator loss.

  See `Improved Training of Wasserstein GANs`
  (https://arxiv.org/abs/1704.00028) for more details.

  The discriminator is evaluated at random points `x_hat` on the straight
  line between each paired real and generated example. For every example the
  penalty is `(sqrt(sum(grad D(x_hat)^2) + epsilon) / target - 1)^2`. When
  `one_sided=True` the term inside the square is clipped at zero from below,
  so only gradient norms above `target` are penalized. The per-example
  penalties are then combined with `weights` and `reduction`.

  Args:
    real_data: Real data.
    generated_data: Output of the generator.
    generator_inputs: Exact argument to pass to the generator, which is used
      as optional conditioning to the discriminator.
    discriminator_fn: A discriminator function that conforms to TF-GAN API.
      If it returns a tuple, only the first element is used.
    discriminator_scope: If not `None`, reuse discriminators from this scope.
    epsilon: A small positive number added for numerical stability when
      computing the gradient norm.
    target: Optional Python number or `Tensor` indicating the target value of
      gradient norm. Defaults to 1.0.
    one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894
      is used. Defaults to `False`.
    weights: Optional `Tensor` of weights for the per-example penalties, which
      have shape `[batch_size]`. Its rank must be either 0 or 1, and it must
      be broadcastable to `[batch_size]`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.

  Raises:
    ValueError: If the rank of data Tensors is unknown, or if the output of
      `discriminator_fn` does not depend on its first input (so no gradient
      exists).
    RuntimeError: If TensorFlow is executing eagerly.
  """
  if tf.executing_eagerly():
    raise RuntimeError('Can\'t use `tf.gradients` when executing eagerly.')
  with tf.compat.v1.name_scope(scope, 'wasserstein_gradient_penalty',
                               (real_data, generated_data)) as scope:
    real_data = tf.convert_to_tensor(value=real_data)
    generated_data = tf.convert_to_tensor(value=generated_data)
    if real_data.shape.ndims is None:
      raise ValueError('`real_data` can\'t have unknown rank.')
    if generated_data.shape.ndims is None:
      raise ValueError('`generated_data` can\'t have unknown rank.')

    differences = generated_data - real_data
    # Prefer the static batch size; fall back to the dynamic one if unknown.
    batch_size = (tf.compat.dimension_value(differences.shape.dims[0]) or
                  tf.shape(input=differences)[0])
    # One interpolation coefficient per example, broadcast over other axes.
    alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1)
    # Sample in the data's dtype. The float32 default would make
    # `alpha * differences` fail for float16 / bfloat16 / float64 inputs.
    alpha = tf.random.uniform(shape=alpha_shape, dtype=differences.dtype)
    interpolates = real_data + (alpha * differences)

    with tf.compat.v1.name_scope(
        ''):  # Clear scope so update ops are added properly.
      # Reuse variables if variables already exists.
      with tf.compat.v1.variable_scope(
          discriminator_scope, 'gpenalty_dscope',
          reuse=tf.compat.v1.AUTO_REUSE):
        disc_interpolates = discriminator_fn(interpolates, generator_inputs)

    if isinstance(disc_interpolates, tuple):
      # ACGAN case: disc outputs more than one tensor
      disc_interpolates = disc_interpolates[0]

    gradients = tf.gradients(ys=disc_interpolates, xs=interpolates)[0]
    if gradients is None:
      # `tf.gradients` returns None for inputs unconnected to the output.
      # Fail here with a clear message instead of deep inside `tf.square`.
      raise ValueError(
          'The output of `discriminator_fn` does not depend on its first '
          'input, so the gradient penalty is undefined.')
    # Squared L2 norm of each example's gradient (reduce all non-batch axes).
    gradient_squares = tf.reduce_sum(
        input_tensor=tf.square(gradients),
        axis=list(range(1, gradients.shape.ndims)))
    # Propagate shape information, if possible.
    if isinstance(batch_size, int):
      gradient_squares.set_shape([
          batch_size] + gradient_squares.shape.as_list()[1:])
    # For numerical stability, add epsilon to the sum before taking the square
    # root. Note tf.norm does not add epsilon.
    slopes = tf.sqrt(gradient_squares + epsilon)
    penalties = slopes / target - 1.0
    if one_sided:
      # Only penalize gradient norms that exceed `target`.
      penalties = tf.maximum(0., penalties)
    penalties_squared = tf.square(penalties)
    penalty = tf.compat.v1.losses.compute_weighted_loss(
        penalties_squared,
        weights,
        scope=scope,
        loss_collection=loss_collection,
        reduction=reduction)

    if add_summaries:
      tf.compat.v1.summary.scalar('gradient_penalty_loss', penalty)

    return penalty