in tensorflow_gan/python/train.py [0:0]
def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Builds the generator and discriminator graphs of an InfoGAN.

  See https://arxiv.org/abs/1606.03657 for background on InfoGAN.

  Graph construction happens in this order:
    1. Inside `generator_scope`, both input groups are converted and combined
       with `+`, and `generator_fn` is applied to the result.
    2. Inside `discriminator_scope`, `discriminator_fn` is applied to the
       generated data; the predicted distributions it returns are then
       validated against the structured inputs.
    3. The discriminator scope is re-entered with `reuse=True` and
       `discriminator_fn` is applied to `real_data`, so both discriminator
       calls share the same variables.

  Args:
    generator_fn: A python callable that receives the combined generator
      inputs and returns the output of the GAN generator.
    discriminator_fn: A python callable taking `(data, generator_inputs)`,
      where `data` is either the real or the generated data. It must return a
      2-tuple `(logits, distribution_list)`: `logits` lie in [-inf, inf], and
      `distribution_list` is a list of TensorFlow distributions, the i-th one
      representing the predicted distribution of the i-th structured input.
    real_data: A Tensor holding real data; it is passed through
      `tf.convert_to_tensor` before use.
    unstructured_generator_inputs: A list of Tensors fed to the generator,
      representing unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors fed to the generator that
      should have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Handy when reusing a
      subgraph that has already been built.
    discriminator_scope: Optional discriminator variable scope. Handy when
      reusing a subgraph that has already been built.

  Returns:
    An InfoGANModel namedtuple. In the discriminator-function position it
    carries a wrapper returning only `discriminator_fn(x, y)[0]` (the logits),
    which conforms to the non-InfoGAN API; the unwrapped `discriminator_fn`
    is passed as the last field.

  Raises:
    ValueError: If TF is executing eagerly.
    ValueError: If the discriminator output is malformed (checked by
      `_validate_distributions`).
    ValueError: If the generator output's shape is not compatible with the
      shape of `real_data`.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.infogan_model` doesn\'t work when executing '
                     'eagerly.')

  # Generator. Input conversion stays inside the scope so the resulting ops
  # are created under it, exactly as before.
  with tf.compat.v1.variable_scope(generator_scope) as gen_vs:
    unstructured = _convert_tensor_or_l_or_d(unstructured_generator_inputs)
    structured = _convert_tensor_or_l_or_d(structured_generator_inputs)
    all_inputs = unstructured + structured
    fake_data = generator_fn(all_inputs)

  # Discriminator on generated data; this call creates its variables.
  with tf.compat.v1.variable_scope(discriminator_scope) as disc_vs:
    fake_outputs, predicted_distributions = discriminator_fn(
        fake_data, all_inputs)
  _validate_distributions(predicted_distributions, structured)

  # Discriminator on real data, sharing the variables created above.
  with tf.compat.v1.variable_scope(disc_vs, reuse=True):
    real_tensor = tf.convert_to_tensor(value=real_data)
    real_outputs, _ = discriminator_fn(real_tensor, all_inputs)

  fake_shape = fake_data.get_shape()
  real_shape = real_tensor.get_shape()
  if not fake_shape.is_compatible_with(real_shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake_shape, real_shape))

  # Trainable variables owned by each sub-network.
  gen_variables = contrib.get_trainable_variables(gen_vs)
  disc_variables = contrib.get_trainable_variables(disc_vs)

  def logits_only_discriminator(data, inputs):
    # Drop the predicted distributions so consumers expecting the
    # non-InfoGAN discriminator API receive just the logits.
    return discriminator_fn(data, inputs)[0]

  return namedtuples.InfoGANModel(
      all_inputs,
      fake_data,
      gen_variables,
      gen_vs,
      generator_fn,
      real_tensor,
      real_outputs,
      fake_outputs,
      disc_variables,
      disc_vs,
      logits_only_discriminator,
      structured,
      predicted_distributions,
      discriminator_fn)