def infogan_model()

in tensorflow_gan/python/train.py [0:0]


def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Builds the InfoGAN graph and bundles its pieces into an `InfoGANModel`.

  The generator is run once on the combined (unstructured + structured) inputs.
  The discriminator is then run twice under the same variable scope: first on
  the generated data, then, with `reuse=True`, on the real data.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: Callable invoked with the combined generator inputs; its
      return value is treated as the generated data.
    discriminator_fn: Callable invoked as `discriminator_fn(data,
      generator_inputs)` for both generated and real data. Must return a
      2-tuple `(logits, distribution_list)`, where `distribution_list` holds
      the predicted distributions for the structured noise.
    real_data: The real data; converted to a Tensor inside the discriminator
      scope.
    unstructured_generator_inputs: Generator inputs representing unstructured
      noise or conditioning.
    structured_generator_inputs: Generator inputs that should have high mutual
      information with the recognizer; the predicted distributions are
      validated against these.
    generator_scope: Variable scope (or scope name) for the generator. Pass an
      existing scope to reuse an already-created subgraph.
    discriminator_scope: Variable scope (or scope name) for the discriminator.
      Pass an existing scope to reuse an already-created subgraph.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If TF is executing eagerly.
    ValueError: If the discriminator output is malformed.
    ValueError: If the generator output's shape is not compatible with the
      shape of `real_data`.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.infogan_model` doesn\'t work when executing '
                     'eagerly.')

  # Generator pass. Input conversion happens inside the generator scope so
  # that any ops it creates are placed under that scope.
  with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
    unstructured_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    all_inputs = unstructured_inputs + structured_inputs
    fake_data = generator_fn(all_inputs)

  # Discriminator pass on the generated data; this call creates the variables.
  with tf.compat.v1.variable_scope(discriminator_scope) as disc_scope:
    fake_outputs, predicted_distributions = discriminator_fn(
        fake_data, all_inputs)
  _validate_distributions(predicted_distributions, structured_inputs)

  # Discriminator pass on the real data, sharing the variables created above.
  # The auxiliary distributions from this pass are not used.
  with tf.compat.v1.variable_scope(disc_scope, reuse=True):
    real_tensor = tf.convert_to_tensor(value=real_data)
    real_outputs, _ = discriminator_fn(real_tensor, all_inputs)

  # The generator must produce data shaped like the real data.
  fake_shape = fake_data.get_shape()
  real_shape = real_tensor.get_shape()
  if not fake_shape.is_compatible_with(real_shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake_shape, real_shape))

  # Collect the trainable variables that live under each model's scope.
  gen_variables = contrib.get_trainable_variables(gen_scope)
  disc_variables = contrib.get_trainable_variables(disc_scope)

  # Fields are passed positionally, in the order InfoGANModel declares them.
  return namedtuples.InfoGANModel(
      all_inputs,
      fake_data,
      gen_variables,
      gen_scope,
      generator_fn,
      real_tensor,
      real_outputs,
      fake_outputs,
      disc_variables,
      disc_scope,
      # Logits-only wrapper so the model conforms to the non-InfoGAN API.
      lambda x, y: discriminator_fn(x, y)[0],
      structured_inputs,
      predicted_distributions,
      discriminator_fn)