# From tensorflow_gan/python/train.py
def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tuple_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tuple_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple. If it also takes `reduction` or `add_summaries`, it will be
      passed those values as well. All TF-GAN loss functions have these
      arguments.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple. If it also takes `reduction` or `add_summaries`, it will
      be passed those values as well. All TF-GAN loss functions have these
      arguments.
    gradient_penalty_weight: If not `None`, must be a non-negative Python number
      or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
      number or `Tensor` indicating the target value of gradient norm. See the
      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
    gradient_penalty_one_sided: If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not None: add a classification loss as in
      https://arxiv.org/abs/1610.09585
    aux_cond_discriminator_weight: If not None: add a classification loss as in
      https://arxiv.org/abs/1610.09585
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns previous stored
      (generated_data, generator_inputs). For example
      `tfgan.features.tensor_pool`. Defaults to None (not using tensor pool).
    reduction: A `tf.losses.Reduction` to apply to loss, if the loss takes an
      argument called `reduction`. Otherwise, this is ignored.
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must be '
        'an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for mutual auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Optionally create pooled model. The discriminator sees (possibly) stale
  # samples drawn from the pool, while the generator loss uses fresh samples.
  if tensor_pool_fn:
    pooled_model = tensor_pool_adjusted_model(model, tensor_pool_fn)
  else:
    pooled_model = model

  # Create standard losses with optional kwargs, if the loss functions accept
  # them.
  def _optional_kwargs(fn, possible_kwargs):
    """Returns a kwargs dictionary of valid kwargs for a given function."""
    # NOTE: `inspect.getargspec` was removed in Python 3.11; use
    # `getfullargspec` instead (its `varkw` field replaces `keywords`).
    # Inspect once instead of twice.
    argspec = inspect.getfullargspec(fn)
    if argspec.varkw is not None:
      # `fn` accepts **kwargs, so it can take everything we might pass.
      return possible_kwargs
    return {k: v for k, v in possible_kwargs.items() if k in argspec.args}

  possible_kwargs = {'reduction': reduction, 'add_summaries': add_summaries}
  gen_loss = generator_loss_fn(
      model, **_optional_kwargs(generator_loss_fn, possible_kwargs))
  dis_loss = discriminator_loss_fn(
      pooled_model, **_optional_kwargs(discriminator_loss_fn, possible_kwargs))

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tuple_losses.wasserstein_gradient_penalty(
        pooled_model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        reduction=reduction,
        add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    gen_info_loss = tuple_losses.mutual_information_penalty(
        model, reduction=reduction, add_summaries=add_summaries)
    # Without a pool, the discriminator sees the same samples as the
    # generator, so the penalties are identical and we avoid recomputing.
    if tensor_pool_fn is None:
      dis_info_loss = gen_info_loss
    else:
      dis_info_loss = tuple_losses.mutual_information_penalty(
          pooled_model, reduction=reduction, add_summaries=add_summaries)
    gen_loss += mutual_information_penalty_weight * gen_info_loss
    dis_loss += mutual_information_penalty_weight * dis_info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tuple_losses.acgan_generator_loss(
        model, reduction=reduction, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tuple_losses.acgan_discriminator_loss(
        pooled_model, reduction=reduction, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss

  # Gather scoped regularization losses, if the scopes are available.
  if model.generator_scope:
    gen_reg_loss = tf.compat.v1.losses.get_regularization_loss(
        model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = tf.compat.v1.losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)