def stargan_model()

in tensorflow_gan/python/train.py [0:0]


def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Builds a StarGAN model and returns its output tensors and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  The generator translates `input_data` into randomly sampled target domains,
  and the generated data is then translated back to the original domains
  (the reconstruction). The discriminator is applied to both the generated
  data and the real input data.

  Args:
    generator_fn: A python callable `generator_fn(inputs, targets)` that
      returns the transformed version of `inputs` conditioned on `targets`.
      `inputs` has shape (n, h, w, c), `targets` has shape (n, num_domains),
      and the returned tensor has the same shape as `inputs`.
    discriminator_fn: A python callable `discriminator_fn(inputs, num_domains)`
      that returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` is the discriminator's source (real/generated)
      prediction with shape (n). `domain_prediction` is its domain
      prediction/classification with shape (n, num_domains).
    input_data: Tensor, or a list of tensors, of shape (n, h, w, c)
      representing the real input images. A list or tuple is concatenated
      along axis 0.
    input_data_domain_label: Tensor, or a list of tensors, of shape
      (batch_size, num_domains) representing the domain label associated with
      the real images. A list or tuple is concatenated along axis 0.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A `StarGANModel` namedtuple containing the tensors needed to compute the
    StarGAN losses, together with the generator/discriminator trainable
    variables, the variable scopes that were opened, and the two model
    functions.

  Raises:
    ValueError: If `input_data_domain_label` is not rank 2 or its shape is not
      fully defined in every dimension.
    ValueError: If TF is executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.stargan_model` doesn\'t work when executing '
                     'eagerly.')

  def _concat_if_sequence(data):
    """Concatenates a list/tuple of tensors along axis 0; else passes through."""
    if not isinstance(data, (list, tuple)):
      return data
    return tf.concat([tf.convert_to_tensor(value=t) for t in data], 0)

  # Convert both arguments before concatenating either, so graph ops are
  # created in the same order as converting-then-merging each input.
  images = _convert_tensor_or_l_or_d(input_data)
  labels = _convert_tensor_or_l_or_d(input_data_domain_label)
  images = _concat_if_sequence(images)
  labels = _concat_if_sequence(labels)

  # The static label shape supplies batch_size and num_domains, so it must be
  # rank 2 and fully known at graph-construction time.
  label_shape = labels.shape
  label_shape.assert_has_rank(2)
  label_shape.assert_is_fully_defined()
  batch_size, num_domains = label_shape.as_list()

  # Forward translation: real images -> randomly sampled target domains.
  # The random target op is created inside the generator scope.
  with tf.compat.v1.variable_scope(generator_scope) as gen_vs:
    target_domains = generate_stargan_random_domain_target(
        batch_size, num_domains)
    fake = generator_fn(images, target_domains)

  # Cycle translation: generated data back to the original labels. The scope
  # object opened above is re-entered with reuse=True.
  with tf.compat.v1.variable_scope(gen_vs, reuse=True):
    cycled = generator_fn(fake, labels)

  # Discriminator predictions for the generated data (first call, no reuse).
  with tf.compat.v1.variable_scope(discriminator_scope) as disc_vs:
    fake_source, fake_domain = discriminator_fn(fake, num_domains)

  # Discriminator predictions for the real data, re-entering the same scope
  # with reuse=True.
  with tf.compat.v1.variable_scope(disc_vs, reuse=True):
    real_source, real_domain = discriminator_fn(images, num_domains)

  # Trainable variables are gathered per scope object.
  gen_vars = contrib.get_trainable_variables(gen_vs)
  disc_vars = contrib.get_trainable_variables(disc_vs)

  return namedtuples.StarGANModel(
      input_data=images,
      input_data_domain_label=labels,
      generated_data=fake,
      generated_data_domain_target=target_domains,
      reconstructed_data=cycled,
      discriminator_input_data_source_predication=real_source,
      discriminator_generated_data_source_predication=fake_source,
      discriminator_input_data_domain_predication=real_domain,
      discriminator_generated_data_domain_predication=fake_domain,
      generator_variables=gen_vars,
      generator_scope=gen_vs,
      generator_fn=generator_fn,
      discriminator_variables=disc_vars,
      discriminator_scope=disc_vs,
      discriminator_fn=discriminator_fn)