def main()

in example_zoo/tensorflow/probability/grammar_vae/trainer/grammar_vae.py [0:0]
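
Trains a grammar variational autoencoder on synthetic data: production sequences are sampled from a fixed ProbabilisticGrammar acting as the data distribution, a second ProbabilisticGrammar (the model) and an amortized ProbabilisticGrammarVariational posterior are fit by maximizing the ELBO with Adam, and scalar summaries and checkpoints are written every 500 steps.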


def main(argv):
  del argv  # unused
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warning(
        "Warning: deleting old log directory at {}".format(FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)
  tf.compat.v1.enable_eager_execution()

  # Set up the grammar and a fixed synthetic data-generating distribution
  # over production sequences.
  grammar = SmilesGrammar()
  synthetic_data_distribution = ProbabilisticGrammar(
      grammar=grammar, latent_size=FLAGS.latent_size, num_units=FLAGS.num_units)

  print("Random examples from synthetic data distribution:")
  for _ in range(5):
    productions = synthetic_data_distribution()
    string = grammar.convert_to_string(productions)
    print(string)

  # Model to be trained and its amortized variational posterior.
  probabilistic_grammar = ProbabilisticGrammar(
      grammar=grammar, latent_size=FLAGS.latent_size, num_units=FLAGS.num_units)
  probabilistic_grammar_variational = ProbabilisticGrammarVariational(
      latent_size=FLAGS.latent_size)

  checkpoint = tf.train.Checkpoint(
      synthetic_data_distribution=synthetic_data_distribution,
      probabilistic_grammar=probabilistic_grammar,
      probabilistic_grammar_variational=probabilistic_grammar_variational)
  global_step = tf.compat.v1.train.get_or_create_global_step()
  optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate)
  # Write TensorBoard summaries to the model directory.
  writer = tf.compat.v2.summary.create_file_writer(FLAGS.model_dir)
  writer.set_as_default()

  start_time = time.time()
  # Training loop: each step draws a fresh batch of production sequences from
  # the synthetic data distribution.
  for step in range(FLAGS.max_steps):
    productions = synthetic_data_distribution()
    with tf.GradientTape() as tape:
      # Sample from amortized variational distribution and record its trace.
      with ed.tape() as variational_tape:
        _ = probabilistic_grammar_variational(productions)

      # Set model trace to take on the data's values and the sample from the
      # variational distribution.
      values = {"latent_code": variational_tape["latent_code_posterior"]}
      values.update({"production_" + str(t): production for t, production
                     in enumerate(tf.unstack(productions, axis=1))})
      with ed.tape() as model_tape:
        with ed.interception(make_value_setter(**values)):
          _ = probabilistic_grammar()

      # Compute the ELBO given the variational sample, averaged over the batch
      # size and the number of time steps (number of productions). Although the
      # ELBO per data point sums over time steps, we average in order to have a
      # value that remains on the same scale across batches.
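      # Concretely, for a batch of size B with T productions per sequence, the
      # quantity computed below is
      #   (1/B) * sum_b [ sum_t log p(x_{b,t} | z_b) - KL(q(z_b | x_b) || p(z_b)) ] / T.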
      log_likelihood = 0.
      for name, rv in six.iteritems(model_tape):
        if name.startswith("production"):
          log_likelihood += rv.distribution.log_prob(rv.value)

      kl = tfp.distributions.kl_divergence(
          variational_tape["latent_code_posterior"].distribution,
          model_tape["latent_code"].distribution)

      timesteps = tf.cast(productions.shape[1], dtype=tf.float32)
      elbo = tf.reduce_mean(input_tensor=log_likelihood - kl) / timesteps
      loss = -elbo
      with tf.compat.v2.summary.record_if(
          lambda: tf.math.equal(0, global_step % 500)):
        tf.compat.v2.summary.scalar(
            "log_likelihood",
            tf.reduce_mean(input_tensor=log_likelihood) / timesteps,
            step=global_step)
        tf.compat.v2.summary.scalar(
            "kl", tf.reduce_mean(input_tensor=kl) / timesteps, step=global_step)
        tf.compat.v2.summary.scalar("elbo", elbo, step=global_step)

    # Update the model and variational parameters jointly.
    variables = (probabilistic_grammar.variables
                 + probabilistic_grammar_variational.variables)
    grads = tape.gradient(loss, variables)
    grads_and_vars = zip(grads, variables)
    optimizer.apply_gradients(grads_and_vars, global_step)

    # Periodically log progress and save a checkpoint.
    if step % 500 == 0:
      duration = time.time() - start_time
      print("Step: {:>3d} Loss: {:.3f} ({:.3f} sec)".format(
          step, loss.numpy(), duration))
      checkpoint.save(file_prefix=FLAGS.model_dir)
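
For context, a minimal sketch of the flag definitions and entry point this main() relies on. The flag names mirror the FLAGS references above; the default values and help strings are illustrative assumptions, not the example's published settings:

from absl import app
from absl import flags

flags.DEFINE_string("model_dir", "/tmp/grammar_vae",
                    "Directory for TensorBoard summaries and checkpoints.")
flags.DEFINE_integer("latent_size", 16, "Dimensionality of the latent code.")
flags.DEFINE_integer("num_units", 128,
                     "Number of hidden units in the grammar networks.")
flags.DEFINE_float("learning_rate", 1e-3, "Adam learning rate.")
flags.DEFINE_integer("max_steps", 5000, "Number of training steps.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
  app.run(main)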