def main()

in example_zoo/tensorflow/probability/vq_vae/trainer/vq_vae.py [0:0]


def main(argv):
  del argv  # unused
  FLAGS.activation = getattr(tf.nn, FLAGS.activation)
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warn("Deleting old log directory at {}".format(
        FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  with tf.Graph().as_default():
    # TODO(b/113163167): Speed up and tune hyperparameters for Bernoulli MNIST.
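    # build_input_pipeline returns a feedable string handle so the same graph
    # can read batches from either the training or the heldout iterator.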
    (images, _, handle,
     training_iterator, heldout_iterator) = build_input_pipeline(
         FLAGS.data_dir, FLAGS.batch_size, heldout_size=10000,
         mnist_type=FLAGS.mnist_type)

    encoder = make_encoder(FLAGS.base_depth,
                           FLAGS.activation,
                           FLAGS.latent_size,
                           FLAGS.code_size)
    decoder = make_decoder(FLAGS.base_depth,
                           FLAGS.activation,
                           FLAGS.latent_size * FLAGS.code_size,
                           IMAGE_SHAPE)
    vector_quantizer = VectorQuantizer(FLAGS.num_codes, FLAGS.code_size)

    codes = encoder(images)
    nearest_codebook_entries, one_hot_assignments = vector_quantizer(codes)
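    # Straight-through estimator: the forward pass uses the quantized codebook
    # entries, while stop_gradient makes the quantization act as the identity
    # in the backward pass, so decoder gradients flow directly to the encoder.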
    codes_straight_through = codes + tf.stop_gradient(
        nearest_codebook_entries - codes)
    decoder_distribution = decoder(codes_straight_through)
    reconstructed_images = decoder_distribution.mean()

    reconstruction_loss = -tf.reduce_mean(
        input_tensor=decoder_distribution.log_prob(images))
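    # The commitment loss penalizes encoder outputs that drift away from their
    # assigned codebook entries; stop_gradient keeps this term from moving the
    # codebook itself, which is updated by the EMA step below.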
    commitment_loss = tf.reduce_mean(
        input_tensor=tf.square(codes -
                               tf.stop_gradient(nearest_codebook_entries)))
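    # add_ema_control_dependencies attaches the exponential-moving-average
    # codebook updates as control dependencies, so the codebook is refreshed
    # every time the commitment loss is evaluated.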
    commitment_loss = add_ema_control_dependencies(
        vector_quantizer,
        one_hot_assignments,
        codes,
        commitment_loss,
        FLAGS.decay)
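    # A Multinomial with total_count=1 and zero logits is a uniform categorical
    # over the num_codes codebook entries, applied independently to each of the
    # latent_size dimensions.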
    prior_dist = tfd.Multinomial(
        total_count=1.0, logits=tf.zeros([FLAGS.latent_size, FLAGS.num_codes]))
    prior_loss = -tf.reduce_mean(
        input_tensor=tf.reduce_sum(
            input_tensor=prior_dist.log_prob(one_hot_assignments), axis=1))

    loss = reconstruction_loss + FLAGS.beta * commitment_loss + prior_loss
    # The prior loss plus the reconstruction loss upper-bounds the marginal
    # negative log-likelihood.
    marginal_nll = prior_loss + reconstruction_loss

    tf.compat.v1.summary.scalar("losses/total_loss", loss)
    tf.compat.v1.summary.scalar("losses/reconstruction_loss",
                                reconstruction_loss)
    tf.compat.v1.summary.scalar("losses/prior_loss", prior_loss)
    tf.compat.v1.summary.scalar("losses/commitment_loss",
                                FLAGS.beta * commitment_loss)

    # Decode samples from a uniform prior for visualization.
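    # Each prior sample is a one-hot assignment per latent dimension; weighting
    # the codebook by these one-hot vectors and summing over the code axis
    # selects the sampled code vectors to feed through the decoder.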
    prior_samples = tf.reduce_sum(
        input_tensor=tf.expand_dims(prior_dist.sample(10), -1) *
        tf.reshape(vector_quantizer.codebook,
                   [1, 1, FLAGS.num_codes, FLAGS.code_size]),
        axis=2)
    decoded_distribution_given_random_prior = decoder(prior_samples)
    random_images = decoded_distribution_given_random_prior.mean()

    # Perform inference by minimizing the loss function.
    optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate)
    train_op = optimizer.minimize(loss)

    summary = tf.compat.v1.summary.merge_all()
    init = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
      summary_writer = tf.compat.v1.summary.FileWriter(FLAGS.model_dir,
                                                       sess.graph)
      sess.run(init)

      # Run the training loop.
      train_handle = sess.run(training_iterator.string_handle())
      heldout_handle = sess.run(heldout_iterator.string_handle())
      for step in range(FLAGS.max_steps):
        start_time = time.time()
        _, loss_value = sess.run([train_op, loss],
                                 feed_dict={handle: train_handle})
        duration = time.time() - start_time
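        # Every 100 steps, evaluate the heldout NLL bound on one batch and log
        # progress to stdout and the summary writer.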
        if step % 100 == 0:
          marginal_nll_val = sess.run(marginal_nll,
                                      feed_dict={handle: heldout_handle})
          print("Step: {:>3d} Training Loss: {:.3f} Heldout NLL: {:.3f} "
                "({:.3f} sec)".format(step, loss_value, marginal_nll_val,
                                      duration))

          # Update the events file.
          summary_str = sess.run(summary, feed_dict={handle: train_handle})
          summary_writer.add_summary(summary_str, step)
          summary_writer.flush()

        # Periodically save a checkpoint and visualize model progress.
        if (step + 1) % FLAGS.viz_steps == 0 or (step + 1) == FLAGS.max_steps:
          checkpoint_file = os.path.join(FLAGS.model_dir, "model.ckpt")
          saver.save(sess, checkpoint_file, global_step=step)

          # Visualize inputs and model reconstructions from the training set.
          images_val, reconstructions_val, random_images_val = sess.run(
              (images, reconstructed_images, random_images),
              feed_dict={handle: train_handle})
          visualize_training(images_val,
                             reconstructions_val,
                             random_images_val,
                             log_dir=FLAGS.model_dir,
                             prefix="step{:05d}_train".format(step))

          # Visualize inputs and model reconstructions from the validation set.
          heldout_images_val, heldout_reconstructions_val = sess.run(
              (images, reconstructed_images),
              feed_dict={handle: heldout_handle})
          visualize_training(heldout_images_val,
                             heldout_reconstructions_val,
                             None,
                             log_dir=FLAGS.model_dir,
                             prefix="step{:05d}_validation".format(step))