example_zoo/tensorflow/probability/generative_adversarial_network/trainer/generative_adversarial_network.py
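# NOTE: This excerpt relies on module-level definitions made earlier in the
# file: os, numpy as np, tensorflow as tf, tensorflow_probability as tfp,
# tfd = tfp.distributions, the mnist dataset reader, FLAGS, IMAGE_SHAPE, and
# the build_fake_data / build_input_pipeline / plot_generated_images helpers.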
def main(argv):
  del argv  # unused
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warning(
        'Warning: deleting old log directory at {}'.format(FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)
  # Collapse the image data dimensions for use with a fully-connected layer.
  image_size = np.prod(IMAGE_SHAPE, dtype=np.int32)
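  # For MNIST this flattens each 28x28 grayscale image into a 784-vector.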
  if FLAGS.fake_data:
    train_images = build_fake_data([10, image_size])
  else:
    # `reshape` is a boolean flag: True flattens images to vectors.
    mnist_data = mnist.read_data_sets(FLAGS.data_dir, reshape=True)
    train_images = mnist_data.train.images
  images = build_input_pipeline(train_images, FLAGS.batch_size)
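  # `images` yields a [batch_size, image_size] batch of flattened pixel
  # values in [0, 1] each time the graph is run.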
  # Build a Generative network. We use the Flipout Monte Carlo estimator
  # for the fully-connected layers: this enables lower variance stochastic
  # gradients than naive reparameterization.
  with tf.compat.v1.name_scope('Generator'):
    random_noise = tf.compat.v1.placeholder(
        tf.float32, shape=[None, FLAGS.hidden_size])
    generative_net = tf.keras.Sequential([
        tfp.layers.DenseFlipout(FLAGS.hidden_size, activation=tf.nn.relu),
        tfp.layers.DenseFlipout(image_size, activation=tf.sigmoid)
    ])
    synthetic_image = generative_net(random_noise)
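    # The sigmoid output keeps synthetic pixels in [0, 1], matching the
    # scaling of the real MNIST images.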
  # Build a Discriminative network. Define the model as a Bernoulli
  # distribution parameterized by logits from a fully-connected layer.
  with tf.compat.v1.name_scope('Discriminator'):
    discriminative_net = tf.keras.Sequential([
        tfp.layers.DenseFlipout(FLAGS.hidden_size, activation=tf.nn.relu),
        tfp.layers.DenseFlipout(1)
    ])
    logits_real = discriminative_net(images)
    logits_fake = discriminative_net(synthetic_image)
    labels_distribution_real = tfd.Bernoulli(logits=logits_real)
    labels_distribution_fake = tfd.Bernoulli(logits=logits_fake)
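  # For a Bernoulli parameterized by logits, log_prob(y) equals the negative
  # sigmoid cross-entropy at label y, so the losses below recover the
  # standard GAN binary cross-entropy objective.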
  # Compute the model loss for the discriminator and generator, averaged over
  # the batch size.
  loss_real = -tf.reduce_mean(
      input_tensor=labels_distribution_real.log_prob(
          tf.ones_like(logits_real)))
  loss_fake = -tf.reduce_mean(
      input_tensor=labels_distribution_fake.log_prob(
          tf.zeros_like(logits_fake)))
  loss_discriminator = loss_real + loss_fake
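  # Non-saturating generator loss: instead of minimizing log(1 - D(G(z))),
  # maximize log D(G(z)) by scoring synthetic images against "real" labels.
  # This yields stronger gradients early in training, when the discriminator
  # easily rejects samples.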
  loss_generator = -tf.reduce_mean(
      input_tensor=labels_distribution_fake.log_prob(
          tf.ones_like(logits_fake)))
  with tf.compat.v1.name_scope('train'):
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate)
    train_op_discriminator = optimizer.minimize(
        loss_discriminator,
        var_list=tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
            scope='Discriminator'))
    train_op_generator = optimizer.minimize(
        loss_generator,
        var_list=tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='Generator'))
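  # Scoping each var_list keeps the updates disjoint: the discriminator step
  # leaves generator weights untouched, and vice versa. Adam creates its slot
  # variables per model variable, so sharing one optimizer instance is fine.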
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for step in range(FLAGS.max_steps + 1):
      # Alternate one gradient update on each network per step.
      _, loss_value_d = sess.run(
          [train_op_discriminator, loss_discriminator],
          feed_dict={random_noise: build_fake_data(
              [FLAGS.batch_size, FLAGS.hidden_size])})
      _, loss_value_g = sess.run(
          [train_op_generator, loss_generator],
          feed_dict={random_noise: build_fake_data(
              [FLAGS.batch_size, FLAGS.hidden_size])})
      # Visualize some synthetic images produced by the generative network.
      if step % FLAGS.viz_steps == 0:
        # Avoid rebinding `images` (the input tensor above) at runtime.
        images_out = sess.run(
            synthetic_image,
            feed_dict={random_noise: build_fake_data(
                [16, FLAGS.hidden_size])})
        plot_generated_images(images_out, fname=os.path.join(
            FLAGS.model_dir,
            'step{:06d}_images.png'.format(step)))
        print('Step: {:>3d} Loss_discriminator: {:.3f} '
              'Loss_generator: {:.3f}'.format(
                  step, loss_value_d, loss_value_g))
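
# `build_fake_data`, `build_input_pipeline`, and `plot_generated_images` are
# helpers defined elsewhere in this file. As an illustration of the shape
# contract only (not the actual implementation), the noise source could be as
# simple as:
#
#   def build_fake_data(shape):
#     # Uniform random values with the requested shape,
#     # e.g. [batch_size, hidden_size] for generator noise.
#     return np.random.uniform(size=shape).astype(np.float32)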