def main()

in example_zoo/tensorflow/probability/generative_adversarial_network/trainer/generative_adversarial_network.py

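For context, the function relies on imports and module-level names defined earlier in the file. A plausible reconstruction is sketched below; the contrib MNIST reader is TF 1.x-only, and IMAGE_SHAPE, the helper functions, and the flag definitions are assumptions based on how they are used, not shown in this excerpt.

import os

from absl import flags
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# TF 1.x only; provides the mnist.read_data_sets call used below.
from tensorflow.contrib.learn.python.learn.datasets import mnist

tfd = tfp.distributions

IMAGE_SHAPE = [28, 28, 1]
FLAGS = flags.FLAGS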

def main(argv):
  del argv  # unused
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warning(
        'Deleting old log directory at {}'.format(FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  # Collapse the image data dimension for use with a fully-connected layer.
  image_size = np.prod(IMAGE_SHAPE, dtype=np.int32)
  if FLAGS.fake_data:
    train_images = build_fake_data([10, image_size])
  else:
    # reshape=True flattens each image to a length-image_size vector.
    mnist_data = mnist.read_data_sets(FLAGS.data_dir, reshape=True)
    train_images = mnist_data.train.images

  images = build_input_pipeline(train_images, FLAGS.batch_size)

  # Build a Generative network. We use the Flipout Monte Carlo estimator
  # for the fully-connected layers: this enables lower variance stochastic
  # gradients than naive reparameterization.
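  # (Flipout decorrelates the per-example weight perturbations within a
  # minibatch, reducing gradient variance relative to sharing one sampled
  # weight perturbation across the whole batch.)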
  with tf.compat.v1.name_scope('Generator'):
    # float32 to match the MNIST images fed to the shared discriminator.
    random_noise = tf.compat.v1.placeholder(
        tf.float32, shape=[None, FLAGS.hidden_size])
    generative_net = tf.keras.Sequential([
        tfp.layers.DenseFlipout(FLAGS.hidden_size, activation=tf.nn.relu),
        tfp.layers.DenseFlipout(image_size, activation=tf.sigmoid)
    ])
    synthetic_image = generative_net(random_noise)

  # Build a Discriminative network. Define the model as a Bernoulli
  # distribution parameterized by logits from a fully-connected layer.
  with tf.compat.v1.name_scope('Discriminator'):
    discriminative_net = tf.keras.Sequential([
        tfp.layers.DenseFlipout(FLAGS.hidden_size, activation=tf.nn.relu),
        tfp.layers.DenseFlipout(1)
    ])
    logits_real = discriminative_net(images)
    logits_fake = discriminative_net(synthetic_image)
    labels_distribution_real = tfd.Bernoulli(logits=logits_real)
    labels_distribution_fake = tfd.Bernoulli(logits=logits_fake)

  # Compute the losses for the discriminator and the generator, averaged
  # over the batch.
  loss_real = -tf.reduce_mean(
      input_tensor=labels_distribution_real.log_prob(
          tf.ones_like(logits_real)))
  loss_fake = -tf.reduce_mean(
      input_tensor=labels_distribution_fake.log_prob(
          tf.zeros_like(logits_fake)))
  loss_discriminator = loss_real + loss_fake
  loss_generator = -tf.reduce_mean(
      input_tensor=labels_distribution_fake.log_prob(
          tf.ones_like(logits_fake)))
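  # Note: the negative Bernoulli log-probability above is exactly sigmoid
  # cross-entropy, and scoring fake samples against real labels in
  # loss_generator is the standard non-saturating generator objective
  # (Goodfellow et al., 2014).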

  with tf.compat.v1.name_scope('train'):
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate)
    train_op_discriminator = optimizer.minimize(
        loss_discriminator,
        var_list=tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator'))
    train_op_generator = optimizer.minimize(
        loss_generator,
        var_list=tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='Generator'))

  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for step in range(FLAGS.max_steps + 1):
      # Iterate gradient updates on each network.
      _, loss_value_d = sess.run([train_op_discriminator, loss_discriminator],
                                 feed_dict={random_noise: build_fake_data(
                                     [FLAGS.batch_size, FLAGS.hidden_size])})
      _, loss_value_g = sess.run([train_op_generator, loss_generator],
                                 feed_dict={random_noise: build_fake_data(
                                     [FLAGS.batch_size, FLAGS.hidden_size])})

      # Visualize some synthetic images produced by the generative network.
      if step % FLAGS.viz_steps == 0:
        synthetic_images = sess.run(
            synthetic_image,
            feed_dict={random_noise: build_fake_data(
                [16, FLAGS.hidden_size])})

        plot_generated_images(synthetic_images, fname=os.path.join(
            FLAGS.model_dir,
            'step{:06d}_images.png'.format(step)))

        print('Step: {:>3d} Loss_discriminator: {:.3f} '
              'Loss_generator: {:.3f}'.format(step, loss_value_d, loss_value_g))
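
main() also calls build_fake_data, build_input_pipeline, and plot_generated_images, which are defined elsewhere in the file. A minimal sketch of the two data helpers under the assumptions the excerpt implies (uniform [0, 1) noise, a shuffled repeating tf.data pipeline); the plotting helper is omitted, and the repo's actual implementations may differ:

def build_fake_data(shape):
  """Uniform random values in [0, 1); serves as noise and as fake MNIST."""
  return np.random.rand(*shape).astype(np.float32)


def build_input_pipeline(train_images, batch_size):
  """An infinite, shuffled batch iterator over the flattened images."""
  dataset = tf.data.Dataset.from_tensor_slices(train_images)
  dataset = dataset.shuffle(50000).repeat().batch(batch_size)
  return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()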