in example_zoo/tensorflow/probability/grammar_vae/trainer/grammar_vae.py [0:0]
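# NOTE (assumption): this excerpt depends on module-level imports and flag
# definitions that live outside the lines shown here. Judging from how the
# names are used below, the surrounding file would need roughly:
#
#   import time
#   import six
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#   from tensorflow_probability import edward2 as ed
#
# plus absl-style FLAGS (model_dir, latent_size, num_units, learning_rate,
# max_steps) and the helpers SmilesGrammar, ProbabilisticGrammar,
# ProbabilisticGrammarVariational, and make_value_setter defined elsewhere in
# the trainer. Treat this list as a sketch inferred from usage, not the file's
# actual import block.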
def main(argv):
  del argv  # unused
  if tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warning(
        "Warning: deleting old log directory at {}".format(FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  tf.compat.v1.enable_eager_execution()
  grammar = SmilesGrammar()
  synthetic_data_distribution = ProbabilisticGrammar(
      grammar=grammar, latent_size=FLAGS.latent_size, num_units=FLAGS.num_units)

  print("Random examples from synthetic data distribution:")
  for _ in range(5):
    productions = synthetic_data_distribution()
    string = grammar.convert_to_string(productions)
    print(string)

  probabilistic_grammar = ProbabilisticGrammar(
      grammar=grammar, latent_size=FLAGS.latent_size, num_units=FLAGS.num_units)
  probabilistic_grammar_variational = ProbabilisticGrammarVariational(
      latent_size=FLAGS.latent_size)

  checkpoint = tf.train.Checkpoint(
      synthetic_data_distribution=synthetic_data_distribution,
      probabilistic_grammar=probabilistic_grammar,
      probabilistic_grammar_variational=probabilistic_grammar_variational)
  global_step = tf.compat.v1.train.get_or_create_global_step()
  optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate)
  writer = tf.compat.v2.summary.create_file_writer(FLAGS.model_dir)
  writer.set_as_default()
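  # The default TF2 summary writer set above is what the tf.compat.v2.summary
  # scalar calls inside the training loop write to; the record_if guard below
  # only records them every 500 global steps.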

  start_time = time.time()
  for step in range(FLAGS.max_steps):
    productions = synthetic_data_distribution()
    with tf.GradientTape() as tape:
      # Sample from amortized variational distribution and record its trace.
      with ed.tape() as variational_tape:
        _ = probabilistic_grammar_variational(productions)

      # Set model trace to take on the data's values and the sample from the
      # variational distribution.
      values = {"latent_code": variational_tape["latent_code_posterior"]}
      values.update({"production_" + str(t): production for t, production
                     in enumerate(tf.unstack(productions, axis=1))})
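
      # make_value_setter (a helper defined elsewhere in this trainer) builds
      # an Edward2 interceptor that pins each named random variable to the
      # given value, so the model trace below is evaluated at the observed
      # productions and at the sampled latent code rather than at fresh
      # samples.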
      with ed.tape() as model_tape:
        with ed.interception(make_value_setter(**values)):
          _ = probabilistic_grammar()

      # Compute the ELBO given the variational sample, averaged over the batch
      # size and the number of time steps (number of productions). Although the
      # ELBO per data point sums over time steps, we average in order to have a
      # value that remains on the same scale across batches.
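      # In symbols, per data point:
      #   ELBO = sum_t log p(production_t | z) - KL(q(z | x) || p(z))
      # with z drawn from the variational posterior; the scalar optimized
      # below is the batch mean of this quantity divided by the number of
      # timesteps.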
      log_likelihood = 0.
      for name, rv in six.iteritems(model_tape):
        if name.startswith("production"):
          log_likelihood += rv.distribution.log_prob(rv.value)

      kl = tfp.distributions.kl_divergence(
          variational_tape["latent_code_posterior"].distribution,
          model_tape["latent_code"].distribution)

      timesteps = tf.cast(productions.shape[1], dtype=tf.float32)
      elbo = tf.reduce_mean(input_tensor=log_likelihood - kl) / timesteps
      loss = -elbo
      with tf.compat.v2.summary.record_if(
          lambda: tf.math.equal(0, global_step % 500)):
        tf.compat.v2.summary.scalar(
            "log_likelihood",
            tf.reduce_mean(input_tensor=log_likelihood) / timesteps,
            step=global_step)
        tf.compat.v2.summary.scalar(
            "kl", tf.reduce_mean(input_tensor=kl) / timesteps, step=global_step)
        tf.compat.v2.summary.scalar("elbo", elbo, step=global_step)

    variables = (probabilistic_grammar.variables
                 + probabilistic_grammar_variational.variables)
    grads = tape.gradient(loss, variables)
    grads_and_vars = zip(grads, variables)
    optimizer.apply_gradients(grads_and_vars, global_step)

    if step % 500 == 0:
      duration = time.time() - start_time
      print("Step: {:>3d} Loss: {:.3f} ({:.3f} sec)".format(
          step, loss, duration))
      checkpoint.save(file_prefix=FLAGS.model_dir)
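

# Assumed entry point: the original trainer defines its FLAGS with absl and
# runs main() via an app runner outside this excerpt. A minimal sketch of that
# wiring (not part of the excerpt above) would be:
#
#   if __name__ == "__main__":
#     tf.compat.v1.app.run(main)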