def model_fn()

in example_zoo/tensorflow/probability/latent_dirichlet_allocation_distributions/trainer/latent_dirichlet_allocation_distributions.py [0:0]


def model_fn(features, labels, mode, params, config):
  """Build the model function for use in an estimator.

  An encoder maps `features` to a posterior over topics, a decoder
  reconstructs `features` from a sample of that posterior, and the negative
  mean ELBO (reconstruction log-probability minus the analytic KL to the
  prior) is minimized with Adam. The prior's variables are excluded from the
  update while the global step is below `params["prior_burn_in_steps"]`.

  Arguments:
    features: The input features for the estimator. `features.shape[1]` is
      passed to `make_decoder`, and axis 1 is summed to get the number of
      words per document (presumably a batch of bag-of-words count vectors --
      confirm against the input function).
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary. Keys read here:
      "activation", "num_topics", "layer_sizes", "prior_initial_value",
      "learning_rate", "prior_burn_in_steps" and "vocabulary".
    config: The RunConfig, unused here.
  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance whose loss is the
      negative mean ELBO, whose train_op applies Adam updates (with the prior
      burn-in described above), and whose eval metrics report the ELBO,
      reconstruction, KL, perplexity and the topic strings.
  """
  del labels, config

  encoder = make_encoder(params["activation"],
                         params["num_topics"],
                         params["layer_sizes"])
  decoder, topics_words = make_decoder(params["num_topics"],
                                       features.shape[1])
  prior, prior_variables = make_prior(params["num_topics"],
                                      params["prior_initial_value"])

  topics_prior = prior()
  # The prior's concentration is also fed to the topics summary below.
  alpha = topics_prior.concentration

  topics_posterior = encoder(features)
  topics_sample = topics_posterior.sample()
  random_reconstruction = decoder(topics_sample)

  reconstruction = random_reconstruction.log_prob(features)
  tf.compat.v1.summary.scalar("reconstruction",
                              tf.reduce_mean(input_tensor=reconstruction))

  # Compute the KL-divergence between two Dirichlets analytically.
  # The sampled KL does not work well for "sparse" distributions
  # (see Appendix D of [2]).
  kl = tfd.kl_divergence(topics_posterior, topics_prior)
  tf.compat.v1.summary.scalar("kl", tf.reduce_mean(input_tensor=kl))

  # Ensure that the KL is non-negative (up to a very small slack).
  # Negative KL can happen due to numerical instability. Routing `kl` through
  # tf.identity under the control dependency makes every later use of `kl`
  # run the assertion first.
  with tf.control_dependencies(
      [tf.compat.v1.assert_greater(kl, -1e-3, message="kl")]):
    kl = tf.identity(kl)

  elbo = reconstruction - kl
  avg_elbo = tf.reduce_mean(input_tensor=elbo)
  tf.compat.v1.summary.scalar("elbo", avg_elbo)
  loss = -avg_elbo

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  optimizer = tf.compat.v1.train.AdamOptimizer(params["learning_rate"])

  # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
  # For the first prior_burn_in_steps steps they are fixed, and then trained
  # jointly with the other parameters.
  grads_and_vars = optimizer.compute_gradients(loss)
  # Filter by object identity. Plain list membership (`var in list`) falls
  # back to `Variable.__eq__`, which under TF2 tensor-equality semantics
  # compares values element-wise instead of comparing identities.
  prior_variable_ids = {id(variable) for variable in prior_variables}
  grads_and_vars_except_prior = [
      (grad, variable) for grad, variable in grads_and_vars
      if id(variable) not in prior_variable_ids]

  def train_op_except_prior():
    return optimizer.apply_gradients(
        grads_and_vars_except_prior,
        global_step=global_step)

  def train_op_all():
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=global_step)

  # Both branches increment global_step; only one runs per step.
  train_op = tf.cond(
      pred=global_step < params["prior_burn_in_steps"],
      true_fn=train_op_except_prior,
      false_fn=train_op_all)

  # The perplexity is an exponent of the average negative ELBO per word.
  # NOTE(review): a document whose counts sum to zero makes this division
  # produce inf/nan -- confirm the input pipeline drops empty documents.
  words_per_document = tf.reduce_sum(input_tensor=features, axis=1)
  log_perplexity = -elbo / words_per_document
  tf.compat.v1.summary.scalar(
      "perplexity", tf.exp(tf.reduce_mean(input_tensor=log_perplexity)))
  (log_perplexity_tensor,
   log_perplexity_update) = tf.compat.v1.metrics.mean(log_perplexity)
  # Average the log-perplexity across eval batches, then exponentiate once.
  perplexity_tensor = tf.exp(log_perplexity_tensor)

  # Obtain the topics summary. Implemented as a py_func for simplicity.
  topics_strings = tf.compat.v1.py_func(
      functools.partial(get_topics_strings, vocabulary=params["vocabulary"]),
      [topics_words, alpha],
      tf.string,
      stateful=False)
  tf.compat.v1.summary.text("topics", topics_strings)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.compat.v1.metrics.mean(elbo),
          "reconstruction": tf.compat.v1.metrics.mean(reconstruction),
          "kl": tf.compat.v1.metrics.mean(kl),
          "perplexity": (perplexity_tensor, log_perplexity_update),
          # The topic strings need no accumulation, so the update op is a no-op.
          "topics": (topics_strings, tf.no_op()),
      },
  )