in example_zoo/tensorflow/probability/latent_dirichlet_allocation_distributions/trainer/latent_dirichlet_allocation_distributions.py [0:0]
def main(argv):
  """Train the LDA estimator and periodically print evaluation metrics.

  Reads all configuration from module-level FLAGS, (re)creates the model
  directory, builds train/eval input functions (real or fake data), then
  alternates training and evaluation every `viz_steps` steps until
  `max_steps` is reached, printing each evaluation metric as it goes.
  """
  del argv  # unused
  run_params = FLAGS.flag_values_dict()
  # Flag values arrive as strings; coerce to the types the model expects.
  run_params["layer_sizes"] = list(map(int, run_params["layer_sizes"]))
  run_params["activation"] = getattr(tf.nn, run_params["activation"])

  # Optionally wipe any previous run, then make sure the directory exists.
  if FLAGS.delete_existing and tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warn("Deleting old log directory at {}".format(
        FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.fake_data:
    train_input_fn, eval_input_fn, vocabulary = build_fake_input_fns(
        FLAGS.batch_size)
  else:
    train_input_fn, eval_input_fn, vocabulary = build_input_fns(
        FLAGS.data_dir, FLAGS.batch_size)
  run_params["vocabulary"] = vocabulary

  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=FLAGS.viz_steps,
  )
  estimator = tf.estimator.Estimator(
      model_fn,
      params=run_params,
      config=run_config,
  )

  num_eval_rounds = FLAGS.max_steps // FLAGS.viz_steps
  for _ in range(num_eval_rounds):
    estimator.train(train_input_fn, steps=FLAGS.viz_steps)
    metrics = estimator.evaluate(eval_input_fn)
    # Print the evaluation results. The keys are strings specified in
    # eval_metric_ops, and the values are NumPy scalars/arrays.
    for name, result in metrics.items():
      print(name)
      if name == "topics":
        # Topics description is a np.array which prints better row-by-row.
        for row in result:
          print(row)
      else:
        print(str(result))
      print("")
    print("")