# example_zoo/tensorflow/probability/vq_vae/trainer/vq_vae.py [0:0]
def main(argv):
  """Trains a VQ-VAE on MNIST, logging losses and visualizations.

  Builds the encoder/decoder/vector-quantizer graph, optimizes the combined
  reconstruction + commitment + prior loss with Adam, and periodically:
  prints a heldout marginal-NLL estimate, writes TF summaries, saves a
  checkpoint, and renders training/validation reconstructions to
  ``FLAGS.model_dir``.

  Args:
    argv: Command-line arguments (unused; flags are read from FLAGS).

  Side effects:
    Deletes and recreates ``FLAGS.model_dir``; mutates ``FLAGS.activation``
    from a string name into the corresponding ``tf.nn`` callable.
  """
  del argv  # unused
  # Resolve the activation flag (a string such as "relu") into the actual
  # tf.nn callable before any layers are constructed. NOTE(review): this
  # mutates FLAGS in place, so main() is not safely re-entrant.
  FLAGS.activation = getattr(tf.nn, FLAGS.activation)
  # Start from a clean log/checkpoint directory each run.
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warn("Deleting old log directory at {}".format(
        FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  with tf.Graph().as_default():
    # TODO(b/113163167): Speed up and tune hyperparameters for Bernoulli MNIST.
    # Feedable-iterator pipeline: `handle` selects between the training and
    # heldout datasets at sess.run time via feed_dict.
    (images, _, handle,
     training_iterator, heldout_iterator) = build_input_pipeline(
         FLAGS.data_dir, FLAGS.batch_size, heldout_size=10000,
         mnist_type=FLAGS.mnist_type)

    encoder = make_encoder(FLAGS.base_depth,
                           FLAGS.activation,
                           FLAGS.latent_size,
                           FLAGS.code_size)
    # Decoder consumes the flattened (latent_size * code_size) code tensor.
    decoder = make_decoder(FLAGS.base_depth,
                           FLAGS.activation,
                           FLAGS.latent_size * FLAGS.code_size,
                           IMAGE_SHAPE)
    vector_quantizer = VectorQuantizer(FLAGS.num_codes, FLAGS.code_size)

    codes = encoder(images)
    nearest_codebook_entries, one_hot_assignments = vector_quantizer(codes)
    # Straight-through estimator: forward pass uses the quantized codebook
    # entries, but gradients flow to `codes` as if quantization were identity
    # (the stop_gradient blocks gradients through the snapping step).
    codes_straight_through = codes + tf.stop_gradient(
        nearest_codebook_entries - codes)
    decoder_distribution = decoder(codes_straight_through)
    reconstructed_images = decoder_distribution.mean()

    # Negative log-likelihood of the inputs under the decoder distribution.
    reconstruction_loss = -tf.reduce_mean(
        input_tensor=decoder_distribution.log_prob(images))
    # Commitment loss pulls encoder outputs toward their (fixed) nearest
    # codebook entries; stop_gradient keeps the codebook out of this term.
    commitment_loss = tf.reduce_mean(
        input_tensor=tf.square(codes -
                               tf.stop_gradient(nearest_codebook_entries)))
    # Presumably wraps the loss with EMA codebook-update control dependencies
    # so codebook entries track assigned encoder outputs — see helper.
    commitment_loss = add_ema_control_dependencies(
        vector_quantizer,
        one_hot_assignments,
        codes,
        commitment_loss,
        FLAGS.decay)
    # Uniform categorical prior over codes (zero logits = equal probability).
    prior_dist = tfd.Multinomial(
        total_count=1.0, logits=tf.zeros([FLAGS.latent_size, FLAGS.num_codes]))
    prior_loss = -tf.reduce_mean(
        input_tensor=tf.reduce_sum(
            input_tensor=prior_dist.log_prob(one_hot_assignments), axis=1))

    loss = reconstruction_loss + FLAGS.beta * commitment_loss + prior_loss
    # Upper bound marginal negative log-likelihood as prior loss +
    # reconstruction loss.
    marginal_nll = prior_loss + reconstruction_loss

    tf.compat.v1.summary.scalar("losses/total_loss", loss)
    tf.compat.v1.summary.scalar("losses/reconstruction_loss",
                                reconstruction_loss)
    tf.compat.v1.summary.scalar("losses/prior_loss", prior_loss)
    tf.compat.v1.summary.scalar("losses/commitment_loss",
                                FLAGS.beta * commitment_loss)

    # Decode samples from a uniform prior for visualization: one-hot prior
    # samples select codebook rows, summed over the code axis.
    prior_samples = tf.reduce_sum(
        input_tensor=tf.expand_dims(prior_dist.sample(10), -1) *
        tf.reshape(vector_quantizer.codebook,
                   [1, 1, FLAGS.num_codes, FLAGS.code_size]),
        axis=2)
    decoded_distribution_given_random_prior = decoder(prior_samples)
    random_images = decoded_distribution_given_random_prior.mean()

    # Perform inference by minimizing the loss function.
    optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate)
    train_op = optimizer.minimize(loss)

    summary = tf.compat.v1.summary.merge_all()
    init = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session() as sess:
      summary_writer = tf.compat.v1.summary.FileWriter(FLAGS.model_dir,
                                                       sess.graph)
      sess.run(init)

      # Run the training loop. The string handles let the same graph feed
      # from either the training or heldout iterator per sess.run call.
      train_handle = sess.run(training_iterator.string_handle())
      heldout_handle = sess.run(heldout_iterator.string_handle())
      for step in range(FLAGS.max_steps):
        start_time = time.time()
        _, loss_value = sess.run([train_op, loss],
                                 feed_dict={handle: train_handle})
        duration = time.time() - start_time
        if step % 100 == 0:
          # Evaluate the heldout NLL bound on a single heldout batch.
          marginal_nll_val = sess.run(marginal_nll,
                                      feed_dict={handle: heldout_handle})
          print("Step: {:>3d} Training Loss: {:.3f} Heldout NLL: {:.3f} "
                "({:.3f} sec)".format(step, loss_value, marginal_nll_val,
                                      duration))

          # Update the events file.
          summary_str = sess.run(summary, feed_dict={handle: train_handle})
          summary_writer.add_summary(summary_str, step)
          summary_writer.flush()

        # Periodically save a checkpoint and visualize model progress.
        if (step + 1) % FLAGS.viz_steps == 0 or (step + 1) == FLAGS.max_steps:
          checkpoint_file = os.path.join(FLAGS.model_dir, "model.ckpt")
          saver.save(sess, checkpoint_file, global_step=step)

          # Visualize inputs and model reconstructions from the training set.
          images_val, reconstructions_val, random_images_val = sess.run(
              (images, reconstructed_images, random_images),
              feed_dict={handle: train_handle})
          visualize_training(images_val,
                             reconstructions_val,
                             random_images_val,
                             log_dir=FLAGS.model_dir,
                             prefix="step{:05d}_train".format(step))

          # Visualize inputs and model reconstructions from the validation set.
          heldout_images_val, heldout_reconstructions_val = sess.run(
              (images, reconstructed_images),
              feed_dict={handle: heldout_handle})
          # No prior samples for the validation visualization (None).
          visualize_training(heldout_images_val,
                             heldout_reconstructions_val,
                             None,
                             log_dir=FLAGS.model_dir,
                             prefix="step{:05d}_validation".format(step))