def stargan_loss()

in tensorflow_gan/python/train.py


def stargan_loss(
    model,
    generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper(
        losses_wargs.wasserstein_generator_loss),
    discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper(
        losses_wargs.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True):
  """StarGAN Loss.

  Args:
    model: (StarGAN) Model output of the stargan_model() function call.
    generator_loss_fn: The loss function on the generator. Takes a
      `StarGANModel` namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `StarGANModel` namedtuple.
    gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10
      per the original paper, https://arxiv.org/abs/1711.09020. Set to 0 or
      None to turn off the gradient penalty.
    gradient_penalty_epsilon: (float) A small positive number added for
      numerical stability when computing the gradient norm.
    gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
      gradient norm. Defaults to 1.0.
    gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    reconstruction_loss_fn: The reconstruction loss function. Defaults to the
      L1 norm; the function must conform to the `tf.losses` API.
    reconstruction_loss_weight: (float) Reconstruction loss weight. Defaults
      to 10.0.
    classification_loss_fn: The loss function on the discriminator's ability
      to classify the domain of the input. Defaults to one-hot softmax cross
      entropy loss; the function must conform to the `tf.losses` API.
    classification_loss_weight: (float) Classification loss weight. Defaults
      to 1.0.
    classification_one_hot: (bool) Whether the label is a one-hot
      representation. Defaults to `True`. If `False`,
      `classification_loss_fn` needs to be a sigmoid cross entropy loss
      instead.
    add_summaries: (bool) Whether to add scalar summaries for the losses.

  Returns:
    GANLoss namedtuple containing the generator loss and the discriminator
    loss.

  Raises:
    ValueError: If the input `StarGANModel.input_data_domain_label` does not
    have rank 2, or its second dimension is not defined.
  """

  def _classification_loss_helper(true_labels, predict_logits, scope_name):
    """Classification Loss Function Helper.

    Args:
      true_labels: Tensor of shape [batch_size, num_domains] representing the
        labels, where each row is a one-hot vector.
      predict_logits: Tensor of shape [batch_size, num_domains] representing
        the predicted domain logits, i.e. the unscaled output of the network.
      scope_name: (string) Name scope of the loss component.

    Returns:
      Single scalar tensor representing the classification loss.
    """

    with tf.compat.v1.name_scope(
        scope_name, values=(true_labels, predict_logits)):

      # `tf.losses`-style losses take labels first and logits second; pass
      # them positionally, since the label keyword differs between softmax
      # (`onehot_labels`) and sigmoid (`multi_class_labels`) losses.
      loss = classification_loss_fn(true_labels, predict_logits)

      if not classification_one_hot:
        # An unreduced sigmoid loss has shape [batch_size, num_domains];
        # sum over domains before averaging over the batch.
        loss = tf.reduce_sum(input_tensor=loss, axis=1)
      loss = tf.reduce_mean(input_tensor=loss)

      if add_summaries:
        tf.compat.v1.summary.scalar(scope_name, loss)

      return loss

  # Check input shape.
  model.input_data_domain_label.shape.assert_has_rank(2)
  model.input_data_domain_label.shape[1:].assert_is_fully_defined()

  # Adversarial Loss.
  generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
  discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Gradient Penalty.
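  # When enabled, this adds the WGAN-GP penalty of Gulrajani et al.
  # (https://arxiv.org/abs/1704.00028): (||grad D(x_hat)||_2 - target)^2 at
  # random interpolates x_hat of real and generated data, scaled by
  # `gradient_penalty_weight`.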
  if _use_aux_loss(gradient_penalty_weight):
    gradient_penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper(
        losses_wargs.wasserstein_gradient_penalty)
    discriminator_loss += gradient_penalty_fn(
        model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries) * gradient_penalty_weight

  # Reconstruction Loss.
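  # `model.reconstructed_data` is the cycle reconstruction G(G(x, c'), c);
  # penalizing its distance to the input encourages the generator to change
  # only the domain while preserving content.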
  reconstruction_loss = reconstruction_loss_fn(model.input_data,
                                               model.reconstructed_data)
  generator_loss += reconstruction_loss * reconstruction_loss_weight
  if add_summaries:
    tf.compat.v1.summary.scalar('reconstruction_loss', reconstruction_loss)

  # Classification Loss.
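  # The generator is rewarded when the discriminator classifies its output as
  # the target domain; the discriminator is trained to classify real input
  # data as its true domain.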
  generator_loss += _classification_loss_helper(
      true_labels=model.generated_data_domain_target,
      predict_logits=model.discriminator_generated_data_domain_predication,
      scope_name='generator_classification_loss') * classification_loss_weight
  discriminator_loss += _classification_loss_helper(
      true_labels=model.input_data_domain_label,
      predict_logits=model.discriminator_input_data_domain_predication,
      scope_name='discriminator_classification_loss'
  ) * classification_loss_weight

  return namedtuples.GANLoss(generator_loss, discriminator_loss)
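
A minimal usage sketch, assuming a StarGAN tuple already built with
`tfgan.stargan_model`; the `generator_fn`, `discriminator_fn`, `images`, and
`one_hot_domain_labels` names are hypothetical stand-ins for user-supplied
networks and data:

import tensorflow_gan as tfgan

# Build the StarGANModel tuple consumed by stargan_loss(). The network
# functions and input tensors below are placeholders.
model = tfgan.stargan_model(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    input_data=images,  # [batch_size, height, width, channels]
    input_data_domain_label=one_hot_domain_labels)  # [batch_size, num_domains]

# The defaults reproduce the paper's objective: Wasserstein adversarial
# losses with gradient penalty (weight 10), L1 reconstruction (weight 10),
# and one-hot softmax domain classification (weight 1).
loss = tfgan.stargan_loss(model)
gen_loss = loss.generator_loss
dis_loss = loss.discriminator_loss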