in tensorflow_gan/python/train.py [0:0]
def stargan_loss(
    model,
    generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper(
        losses_wargs.wasserstein_generator_loss),
    discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper(
        losses_wargs.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True):
  """Builds the StarGAN generator and discriminator losses.

  The returned losses are assembled as follows:

    * generator loss = adversarial generator loss
        + reconstruction_loss_weight * reconstruction loss
          (between `model.input_data` and `model.reconstructed_data`)
        + classification_loss_weight * domain-classification loss on the
          generated data against `model.generated_data_domain_target`.
    * discriminator loss = adversarial discriminator loss
        + gradient_penalty_weight * gradient penalty (only when enabled)
        + classification_loss_weight * domain-classification loss on the
          input data against `model.input_data_domain_label`.

  Args:
    model: (StarGAN) Model output of the stargan_model() function call.
    generator_loss_fn: The loss function on the generator. Called with a
      `StarGANModel` namedtuple and an `add_summaries` keyword argument.
    discriminator_loss_fn: The loss function on the discriminator. Called with
      a `StarGANModel` namedtuple and an `add_summaries` keyword argument.
    gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10
      per the original paper https://arxiv.org/abs/1711.09020. Set to 0 or
      None to turn off gradient penalty.
    gradient_penalty_epsilon: (float) A small positive number added for
      numerical stability when computing the gradient norm.
    gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
      gradient norm. Defaults to 1.0.
    gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    reconstruction_loss_fn: The reconstruction loss function. Defaults to the
      L1-norm, and the function must conform to the `tf.losses` API. It is
      called positionally as `fn(model.input_data, model.reconstructed_data)`.
    reconstruction_loss_weight: Reconstruction loss weight. Defaults to 10.0.
    classification_loss_fn: The loss function on the discriminator's ability
      to classify the domain of the input. Defaults to one-hot softmax cross
      entropy loss, and the function must conform to the `tf.losses` API. It
      is always called with the keyword arguments `onehot_labels=` and
      `logits=`.
    classification_loss_weight: (float) Classification loss weight. Defaults
      to 1.0.
    classification_one_hot: (bool) Whether the label is a one-hot
      representation. Defaults to True. If False, `classification_loss_fn`
      needs to be a sigmoid cross entropy loss instead; its output is then
      summed over axis 1 before being averaged. NOTE(review): this implies the
      function must return an unreduced, at-least-rank-2 loss and accept the
      `onehot_labels=` keyword; confirm any sigmoid loss passed here does.
    add_summaries: (bool) Whether to add the loss components to the summary.

  Returns:
    GANLoss namedtuple where we have generator loss and discriminator loss.

  Raises:
    ValueError: If input StarGANModel.input_data_domain_label does not have
      rank 2, or its second dimension (index 1) is not fully defined.
  """

  def _domain_classification_loss(labels, logits, scope):
    """Returns the scalar domain-classification loss built under `scope`.

    Args:
      labels: Tensor of shape [batch_size, num_domains] holding the target
        domain labels, where each row is a one-hot vector.
      logits: Tensor of shape [batch_size, num_domains] holding the predicted
        label logits, which are UNSCALED outputs from the network.
      scope: (string) Name scope for the loss ops; also used as the summary
        name when `add_summaries` is True.

    Returns:
      Single scalar tensor representing the classification loss.
    """
    with tf.compat.v1.name_scope(scope, values=(labels, logits)):
      raw_loss = classification_loss_fn(onehot_labels=labels, logits=logits)
      # Non one-hot (multi-label) losses are summed across the domain axis
      # first, so each example contributes one value to the batch mean.
      if not classification_one_hot:
        raw_loss = tf.reduce_sum(input_tensor=raw_loss, axis=1)
      mean_loss = tf.reduce_mean(input_tensor=raw_loss)
      if add_summaries:
        tf.compat.v1.summary.scalar(scope, mean_loss)
    return mean_loss

  # Validate the domain-label shape before building any loss ops.
  label_shape = model.input_data_domain_label.shape
  label_shape.assert_has_rank(2)
  label_shape[1:].assert_is_fully_defined()

  # Adversarial terms (generator first, then discriminator).
  gen_total = generator_loss_fn(model, add_summaries=add_summaries)
  dis_total = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Optional Wasserstein gradient penalty, applied to the discriminator only.
  if _use_aux_loss(gradient_penalty_weight):
    penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper(
        losses_wargs.wasserstein_gradient_penalty)
    penalty = penalty_fn(
        model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries)
    dis_total += penalty * gradient_penalty_weight

  # Reconstruction term: compares the input with its reconstruction and is
  # charged to the generator.
  recon = reconstruction_loss_fn(model.input_data, model.reconstructed_data)
  gen_total += recon * reconstruction_loss_weight
  if add_summaries:
    tf.compat.v1.summary.scalar('reconstruction_loss', recon)

  # Domain-classification terms: the generator is scored on its generated
  # data against the requested target domain, the discriminator on the real
  # input data against its true domain label.
  gen_cls = _domain_classification_loss(
      labels=model.generated_data_domain_target,
      logits=model.discriminator_generated_data_domain_predication,
      scope='generator_classification_loss')
  gen_total += gen_cls * classification_loss_weight

  dis_cls = _domain_classification_loss(
      labels=model.input_data_domain_label,
      logits=model.discriminator_input_data_domain_predication,
      scope='discriminator_classification_loss')
  dis_total += dis_cls * classification_loss_weight

  return namedtuples.GANLoss(gen_total, dis_total)