in tensorflow_gan/python/losses/losses_impl.py [0:0]
def wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs,
    discriminator_fn,
    discriminator_scope,
    epsilon=1e-10,
    target=1.0,
    one_sided=False,
    weights=1.0,
    scope=None,
    loss_collection=tf.compat.v1.GraphKeys.LOSSES,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
"""The gradient penalty for the Wasserstein discriminator loss.
See `Improved Training of Wasserstein GANs`
(https://arxiv.org/abs/1704.00028) for more details.
Args:
real_data: Real data.
generated_data: Output of the generator.
generator_inputs: Exact argument to pass to the generator, which is used
as optional conditioning to the discriminator.
discriminator_fn: A discriminator function that conforms to TF-GAN API.
discriminator_scope: If not `None`, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when
computing the gradient norm.
target: Optional Python number or `Tensor` indicating the target value of
gradient norm. Defaults to 1.0.
one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894
is used. Defaults to `False`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data` and `generated_data`, and must be broadcastable to
them (i.e., all dimensions must be either `1`, or the same as the
corresponding dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
Raises:
ValueError: If the rank of data Tensors is unknown.
RuntimeError: If TensorFlow is executing eagerly.
"""
  if tf.executing_eagerly():
    raise RuntimeError('Can\'t use `tf.gradients` when executing eagerly.')
  with tf.compat.v1.name_scope(scope, 'wasserstein_gradient_penalty',
                               (real_data, generated_data)) as scope:
    real_data = tf.convert_to_tensor(value=real_data)
    generated_data = tf.convert_to_tensor(value=generated_data)
    if real_data.shape.ndims is None:
      raise ValueError('`real_data` can\'t have unknown rank.')
    if generated_data.shape.ndims is None:
      raise ValueError('`generated_data` can\'t have unknown rank.')
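
    # Sample one interpolation coefficient per example (shape
    # [batch_size, 1, ..., 1]) so that each interpolate lies on the straight
    # line between a real/generated pair, broadcast over all non-batch dims.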
    differences = generated_data - real_data
    batch_size = (tf.compat.dimension_value(differences.shape.dims[0]) or
                  tf.shape(input=differences)[0])
    alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1)
    alpha = tf.random.uniform(shape=alpha_shape)
    interpolates = real_data + (alpha * differences)
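
    # Score the interpolates with the *same* discriminator weights used for
    # the main losses, so the penalty constrains the network being trained.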
    with tf.compat.v1.name_scope(
        ''):  # Clear the name scope so update ops are added properly.
      # Reuse the discriminator variables if they already exist.
      with tf.compat.v1.variable_scope(
          discriminator_scope, 'gpenalty_dscope',
          reuse=tf.compat.v1.AUTO_REUSE):
        disc_interpolates = discriminator_fn(interpolates, generator_inputs)

    if isinstance(disc_interpolates, tuple):
      # ACGAN case: the discriminator outputs more than one tensor.
      disc_interpolates = disc_interpolates[0]
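
    # Gradient of the discriminator output with respect to the interpolates,
    # then its per-example squared L2 norm (summed over all non-batch dims).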
    gradients = tf.gradients(ys=disc_interpolates, xs=interpolates)[0]
    gradient_squares = tf.reduce_sum(
        input_tensor=tf.square(gradients),
        axis=list(range(1, gradients.shape.ndims)))
    # Propagate shape information, if possible.
    if isinstance(batch_size, int):
      gradient_squares.set_shape(
          [batch_size] + gradient_squares.shape.as_list()[1:])
    # For numerical stability, add epsilon to the sum before taking the square
    # root. Note tf.norm does not add epsilon.
    slopes = tf.sqrt(gradient_squares + epsilon)
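    # Penalty per example: (||grad|| / target - 1)^2, or
    # max(0, ||grad|| / target - 1)^2 when `one_sided=True`.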
    penalties = slopes / target - 1.0
    if one_sided:
      penalties = tf.maximum(0., penalties)
    penalties_squared = tf.square(penalties)
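    # `compute_weighted_loss` applies `weights`, reduces according to
    # `reduction`, and registers the result in `loss_collection`.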
    penalty = tf.compat.v1.losses.compute_weighted_loss(
        penalties_squared,
        weights,
        scope=scope,
        loss_collection=loss_collection,
        reduction=reduction)

    if add_summaries:
      tf.compat.v1.summary.scalar('gradient_penalty_loss', penalty)

    return penalty
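

# --- Usage sketch (illustrative only; not part of losses_impl.py) ---
# A minimal graph-mode example of calling the function above with a
# hypothetical linear critic. `toy_critic`, the batch tensors, and the shapes
# are made up for illustration; graph mode is required because the penalty
# relies on `tf.gradients`. In the released package the same function is also
# exposed as `tfgan.losses.wasserstein_gradient_penalty`.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()


def toy_critic(data, unused_generator_inputs):
  # A single linear layer producing one critic score per example.
  kernel = tf.compat.v1.get_variable('kernel', shape=[4, 1], dtype=tf.float32)
  return tf.matmul(data, kernel)


real_batch = tf.random.uniform([8, 4])
fake_batch = tf.random.uniform([8, 4])

# Build the critic's variables once under a known scope...
with tf.compat.v1.variable_scope('discriminator') as dis_scope:
  toy_critic(real_batch, None)

# ...then let the penalty reuse them via `discriminator_scope`.
gp_loss = wasserstein_gradient_penalty(
    real_data=real_batch,
    generated_data=fake_batch,
    generator_inputs=None,
    discriminator_fn=toy_critic,
    discriminator_scope=dis_scope,
    one_sided=False)

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  print('gradient penalty:', sess.run(gp_loss))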