in example_zoo/tensorflow/probability/latent_dirichlet_allocation_edward2/trainer/latent_dirichlet_allocation_edward2.py [0:0]
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an Estimator.

  Arguments:
    features: The input features for the Estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, config
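  # `features` is expected to be a [batch_size, num_words] bag-of-words matrix
  # of per-document word counts.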

  # Set up the model's learnable parameters.
  logit_concentration = tf.compat.v1.get_variable(
      "logit_concentration",
      shape=[1, params["num_topics"]],
      initializer=tf.compat.v1.initializers.constant(
          _softplus_inverse(params["prior_initial_value"])))
  concentration = _clip_dirichlet_parameters(
      tf.nn.softplus(logit_concentration))
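  # `concentration` parameterizes the Dirichlet prior over per-document topic
  # proportions; the softplus and clipping above keep it positive and
  # numerically stable.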

  num_words = features.shape[1]
  topics_words_logits = tf.compat.v1.get_variable(
      "topics_words_logits",
      shape=[params["num_topics"], num_words],
      initializer=tf.compat.v1.glorot_normal_initializer())
  topics_words = tf.nn.softmax(topics_words_logits, axis=-1)
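  # Each row of `topics_words` is one topic's distribution over the vocabulary
  # (softmax over the last axis).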

  # Compute expected log-likelihood. First, sample from the variational
  # distribution; second, compute the log-likelihood given the sample.
  lda_variational = make_lda_variational(
      params["activation"],
      params["num_topics"],
      params["layer_sizes"])
  with ed.tape() as variational_tape:
    _ = lda_variational(features)
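  # `variational_tape` now records the Edward2 random variables created inside
  # `lda_variational`, in particular "topics_posterior", the variational
  # posterior over per-document topic proportions.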

  with ed.tape() as model_tape:
    with ed.interception(
        make_value_setter(topics=variational_tape["topics_posterior"])):
      posterior_predictive = latent_dirichlet_allocation(concentration,
                                                         topics_words)
  log_likelihood = posterior_predictive.distribution.log_prob(features)
  tf.compat.v1.summary.scalar("log_likelihood",
                              tf.reduce_mean(input_tensor=log_likelihood))
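  # The quantity being maximized below is
  #   ELBO(x) = E_q[log p(x | topics)] - KL(q(topics | x) || p(topics)),
  # with the expectation approximated by the single posterior sample drawn
  # above.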

  # Compute the KL-divergence between two Dirichlets analytically.
  # The sampled KL does not work well for "sparse" distributions
  # (see Appendix D of [2]).
  kl = variational_tape["topics_posterior"].distribution.kl_divergence(
      model_tape["topics"].distribution)
  tf.compat.v1.summary.scalar("kl", tf.reduce_mean(input_tensor=kl))

  # Ensure that the KL is non-negative (up to a very small slack).
  # Negative KL can happen due to numerical instability.
  with tf.control_dependencies(
      [tf.compat.v1.assert_greater(kl, -1e-3, message="kl")]):
    kl = tf.identity(kl)

  elbo = log_likelihood - kl
  avg_elbo = tf.reduce_mean(input_tensor=elbo)
  tf.compat.v1.summary.scalar("elbo", avg_elbo)
  loss = -avg_elbo

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  optimizer = tf.compat.v1.train.AdamOptimizer(params["learning_rate"])

  # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
  # For the first prior_burn_in_steps steps they are fixed, and then trained
  # jointly with the other parameters.
  grads_and_vars = optimizer.compute_gradients(loss)
  grads_and_vars_except_prior = [
      x for x in grads_and_vars if x[1] != logit_concentration]
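  # Dropping the (gradient, variable) pair for `logit_concentration` means the
  # prior concentration receives no updates while the burn-in branch is taken.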

  def train_op_except_prior():
    return optimizer.apply_gradients(
        grads_and_vars_except_prior,
        global_step=global_step)

  def train_op_all():
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=global_step)

  train_op = tf.cond(
      pred=global_step < params["prior_burn_in_steps"],
      true_fn=train_op_except_prior,
      false_fn=train_op_all)

  # The perplexity is the exponential of the average negative ELBO per word.
  words_per_document = tf.reduce_sum(input_tensor=features, axis=1)
  log_perplexity = -elbo / words_per_document
  tf.compat.v1.summary.scalar(
      "perplexity", tf.exp(tf.reduce_mean(input_tensor=log_perplexity)))
  (log_perplexity_tensor,
   log_perplexity_update) = tf.compat.v1.metrics.mean(log_perplexity)
  perplexity_tensor = tf.exp(log_perplexity_tensor)
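  # tf.compat.v1.metrics.mean returns a (value, update_op) pair; exponentiating
  # the streaming mean of log-perplexity gives the perplexity reported at
  # evaluation time.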

  # Obtain the topics summary. Implemented as a py_func for simplicity.
  topics = tf.compat.v1.py_func(
      functools.partial(get_topics_strings, vocabulary=params["vocabulary"]),
      [topics_words, concentration],
      tf.string,
      stateful=False)
  tf.compat.v1.summary.text("topics", topics)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.compat.v1.metrics.mean(elbo),
          "log_likelihood": tf.compat.v1.metrics.mean(log_likelihood),
          "kl": tf.compat.v1.metrics.mean(kl),
          "perplexity": (perplexity_tensor, log_perplexity_update),
          "topics": (topics, tf.no_op()),
      },
  )
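

# --- Usage sketch (illustrative, not part of the original example) ---
# A minimal way this `model_fn` could be wired into a tf.estimator.Estimator.
# The helper name `_example_build_estimator`, the `model_dir` default, and the
# hyperparameter values below are assumptions chosen for illustration; the
# surrounding example constructs `params` elsewhere.
def _example_build_estimator(vocabulary, model_dir="/tmp/lda"):
  params = {
      "num_topics": 50,
      "layer_sizes": [300, 300, 300],
      "activation": tf.nn.relu,
      "prior_initial_value": 0.7,
      "prior_burn_in_steps": 120000,
      "learning_rate": 3e-4,
      "vocabulary": vocabulary,
  }
  return tf.estimator.Estimator(
      model_fn=model_fn,
      params=params,
      config=tf.estimator.RunConfig(model_dir=model_dir))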