def main()

in example_zoo/tensorflow/probability/deep_exponential_family/trainer/deep_exponential_family.py [0:0]


def main(argv):
  del argv  # unused
  FLAGS.layer_sizes = [int(layer_size) for layer_size in FLAGS.layer_sizes]
  if len(FLAGS.layer_sizes) != 3:
    raise NotImplementedError("Specifying fewer or more than 3 layers is not "
                              "currently available.")
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warning(
        "Deleting old log directory at {}".format(FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.fake_data:
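    # Simulate a small documents-by-vocabulary matrix of Poisson word counts.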
    bag_of_words = np.random.poisson(1., size=[10, 25])
    words = [str(i) for i in range(25)]
  else:
    bag_of_words, words = load_nips2011_papers(FLAGS.data_dir)

  total_count = np.sum(bag_of_words)
  bag_of_words = tf.cast(bag_of_words, dtype=tf.float32)
  data_size, feature_size = bag_of_words.shape

  # Compute expected log-likelihood. First, sample from the variational
  # distribution; second, compute the log-likelihood given the sample.
  qw2, qw1, qw0, qz2, qz1, qz0 = deep_exponential_family_variational(
      data_size,
      feature_size,
      FLAGS.layer_sizes)

  with ed.tape() as model_tape:
    with ed.interception(make_value_setter(w2=qw2, w1=qw1, w0=qw0,
                                           z2=qz2, z1=qz1, z0=qz0)):
      posterior_predictive = deep_exponential_family(data_size,
                                                     feature_size,
                                                     FLAGS.layer_sizes,
                                                     FLAGS.shape)

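  # With the model's random variables pinned to the variational samples,
  # posterior_predictive is the data distribution conditioned on those samples,
  # and model_tape records every prior random variable by name for the KL term.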
  log_likelihood = posterior_predictive.distribution.log_prob(bag_of_words)
  log_likelihood = tf.reduce_sum(input_tensor=log_likelihood)
  tf.compat.v1.summary.scalar("log_likelihood", log_likelihood)

  # Compute analytic KL-divergence between variational and prior distributions.
  kl = 0.
  for rv_name, variational_rv in [("z0", qz0), ("z1", qz1), ("z2", qz2),
                                  ("w0", qw0), ("w1", qw1), ("w2", qw2)]:
    kl += tf.reduce_sum(
        input_tensor=variational_rv.distribution.kl_divergence(
            model_tape[rv_name].distribution))

  tf.compat.v1.summary.scalar("kl", kl)

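  # ELBO = expected log-likelihood minus KL; it lower-bounds the marginal
  # log-likelihood of the data.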
  elbo = log_likelihood - kl
  tf.compat.v1.summary.scalar("elbo", elbo)
  optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate)
  train_op = optimizer.minimize(-elbo)

  sess = tf.compat.v1.Session()
  summary = tf.compat.v1.summary.merge_all()
  summary_writer = tf.compat.v1.summary.FileWriter(FLAGS.model_dir, sess.graph)

  sess.run(tf.compat.v1.global_variables_initializer())
  for step in range(FLAGS.max_steps):
    start_time = time.time()
    _, elbo_value = sess.run([train_op, elbo])
    if step % 500 == 0:
      duration = time.time() - start_time
      print("Step: {:>3d} Loss: {:.3f} ({:.3f} sec)".format(
          step, elbo_value, duration))
      summary_str = sess.run(summary)
      summary_writer.add_summary(summary_str, step)
      summary_writer.flush()

      # Compute perplexity of the full data set. The model's negative
      # log-likelihood of data is upper bounded by the variational objective.
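      # perplexity = exp(NLL / total word count); since NLL <= -ELBO,
      # exp(-elbo_value / total_count) is an upper bound on the perplexity.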
      negative_log_likelihood = -elbo_value
      perplexity = np.exp(negative_log_likelihood / total_count)
      print("Negative log-likelihood <= {:0.3f}".format(
          negative_log_likelihood))
      print("Perplexity <= {:0.3f}".format(perplexity))

      # Print top 10 words for first 10 topics.
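      # Each row of qw0 holds one bottom-layer factor's (topic's) word weights,
      # so its largest entries are that topic's most prominent words.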
      qw0_values = sess.run(qw0)
      for k in range(min(10, FLAGS.layer_sizes[-1])):
        top_words_idx = qw0_values[k, :].argsort()[-10:][::-1]
        top_words = " ".join([words[i] for i in top_words_idx])
        print("Topic {}: {}".format(k, top_words))