in tensorflow_gan/python/train.py [0:0]
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Builds the StarGAN graph and returns its tensors, scopes and variables.

  StarGAN (https://arxiv.org/abs/1711.09020) translates images between several
  domains with a single generator. This function wires up:

    1. the generator mapping `input_data` to random target domains,
    2. the same generator (reused) mapping the generated data back to the
       original domains given by `input_data_domain_label`, producing the
       reconstruction,
    3. the discriminator applied to the generated data and (reused) to the
       real input data.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` and returns
      `generated_data`, the transformed version of `inputs` according to
      `targets`. `inputs` has shape (n, h, w, c), `targets` has shape
      (n, num_domains), and `generated_data` has the same shape as `inputs`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` and
      returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` is the discriminator's real/generated prediction
      with shape (n), and `domain_prediction` is its domain classification
      with shape (n, num_domains).
    input_data: Tensor, or list of tensors, of shape (n, h, w, c) holding the
      real input images. A list/tuple is concatenated along axis 0.
    input_data_domain_label: Tensor, or list of tensors, of shape
      (batch_size, num_domains) holding the domain labels of the real images.
      A list/tuple is concatenated along axis 0.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A `StarGANModel` namedtuple with the tensors needed to compute the StarGAN
    losses, together with the generator/discriminator functions, their
    variable scopes and their trainable variables.

  Raises:
    ValueError: If `input_data_domain_label` is not rank 2, or its shape is not
      fully defined in every dimension.
    ValueError: If TF is executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.stargan_model` doesn\'t work when executing '
                     'eagerly.')

  def _as_single_tensor(value):
    # Concatenate a list/tuple of tensors along axis 0; leave anything else
    # untouched.
    if not isinstance(value, (list, tuple)):
      return value
    pieces = [tf.convert_to_tensor(value=piece) for piece in value]
    return tf.concat(pieces, 0)

  # Convert both inputs first, then collapse any sequences into one tensor.
  images = _convert_tensor_or_l_or_d(input_data)
  labels = _convert_tensor_or_l_or_d(input_data_domain_label)
  images = _as_single_tensor(images)
  labels = _as_single_tensor(labels)

  # batch_size and num_domains are read as static Python ints, so the label
  # shape must be rank 2 and fully known at graph-construction time.
  label_shape = labels.shape
  label_shape.assert_has_rank(2)
  label_shape.assert_is_fully_defined()
  batch_size, num_domains = label_shape.as_list()

  # Generator, first call (no reuse): real images -> random target domains.
  # The target-domain tensor is built inside the generator scope.
  with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
    target_domains = generate_stargan_random_domain_target(
        batch_size, num_domains)
    fake_images = generator_fn(images, target_domains)
  # Generator, second call (reuse=True on the captured scope object):
  # generated images -> back to their original domains.
  with tf.compat.v1.variable_scope(gen_scope, reuse=True):
    cycled_images = generator_fn(fake_images, labels)

  # Discriminator, first call (no reuse) on the generated images, then reused
  # on the real images.
  with tf.compat.v1.variable_scope(discriminator_scope) as disc_scope:
    fake_source_pred, fake_domain_pred = discriminator_fn(
        fake_images, num_domains)
  with tf.compat.v1.variable_scope(disc_scope, reuse=True):
    real_source_pred, real_domain_pred = discriminator_fn(
        images, num_domains)

  gen_vars = contrib.get_trainable_variables(gen_scope)
  disc_vars = contrib.get_trainable_variables(disc_scope)

  # NOTE: the `*_predication` spellings are StarGANModel field names and must
  # stay as-is for compatibility.
  return namedtuples.StarGANModel(
      input_data=images,
      input_data_domain_label=labels,
      generated_data=fake_images,
      generated_data_domain_target=target_domains,
      reconstructed_data=cycled_images,
      discriminator_input_data_source_predication=real_source_pred,
      discriminator_generated_data_source_predication=fake_source_pred,
      discriminator_input_data_domain_predication=real_domain_pred,
      discriminator_generated_data_domain_predication=fake_domain_pred,
      generator_variables=gen_vars,
      generator_scope=gen_scope,
      generator_fn=generator_fn,
      discriminator_variables=disc_vars,
      discriminator_scope=disc_scope,
      discriminator_fn=discriminator_fn)