def model_fn()

in example_zoo/tensorflow/probability/latent_dirichlet_allocation_edward2/trainer/latent_dirichlet_allocation_edward2.py


def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an Estimator.

  Arguments:
    features: The input features for the Estimator.
    labels: The labels, unused here.
    mode: Signifies whether the Estimator is running in train, eval, or
      predict mode.
    params: Hyperparameters as a dictionary (num_topics, layer_sizes,
      learning_rate, etc.).
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, config

  # Set up the model's learnable parameters.
  logit_concentration = tf.compat.v1.get_variable(
      "logit_concentration",
      shape=[1, params["num_topics"]],
      initializer=tf.compat.v1.initializers.constant(
          _softplus_inverse(params["prior_initial_value"])))
  concentration = _clip_dirichlet_parameters(
      tf.nn.softplus(logit_concentration))
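  # _softplus_inverse is the inverse of tf.nn.softplus, so the (clipped)
  # concentration starts out at params["prior_initial_value"] and stays
  # positive through the softplus as logit_concentration is trained.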

  num_words = features.shape[1]
  topics_words_logits = tf.compat.v1.get_variable(
      "topics_words_logits",
      shape=[params["num_topics"], num_words],
      initializer=tf.compat.v1.glorot_normal_initializer())
  topics_words = tf.nn.softmax(topics_words_logits, axis=-1)
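  # Each row of topics_words is a distribution over the vocabulary: the
  # softmax over the last axis makes every topic's word probabilities sum to 1.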

  # Compute expected log-likelihood. First, sample from the variational
  # distribution; second, compute the log-likelihood given the sample.
  lda_variational = make_lda_variational(
      params["activation"],
      params["num_topics"],
      params["layer_sizes"])
  with ed.tape() as variational_tape:
    _ = lda_variational(features)
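  # The tape records the random variables created inside lda_variational, in
  # particular "topics_posterior", whose sampled value and distribution are
  # reused below.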

  with ed.tape() as model_tape:
    with ed.interception(
        make_value_setter(topics=variational_tape["topics_posterior"])):
      posterior_predictive = latent_dirichlet_allocation(concentration,
                                                         topics_words)
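  # Intercepting with make_value_setter pins the model's "topics" random
  # variable to the sample drawn above from the variational posterior, so
  # posterior_predictive is the model's observation conditioned on that
  # single sample.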

  log_likelihood = posterior_predictive.distribution.log_prob(features)
  tf.compat.v1.summary.scalar("log_likelihood",
                              tf.reduce_mean(input_tensor=log_likelihood))

  # Compute the KL-divergence between two Dirichlets analytically.
  # The sampled KL does not work well for "sparse" distributions
  # (see Appendix D of [2]).
  kl = variational_tape["topics_posterior"].distribution.kl_divergence(
      model_tape["topics"].distribution)
  tf.compat.v1.summary.scalar("kl", tf.reduce_mean(input_tensor=kl))

  # Ensure that the KL is non-negative (up to a very small slack).
  # Negative KL can happen due to numerical instability.
  with tf.control_dependencies(
      [tf.compat.v1.assert_greater(kl, -1e-3, message="kl")]):
    kl = tf.identity(kl)

  elbo = log_likelihood - kl
  avg_elbo = tf.reduce_mean(input_tensor=elbo)
  tf.compat.v1.summary.scalar("elbo", avg_elbo)
  loss = -avg_elbo
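  # Written out, this is a single-sample Monte Carlo estimate of
  #   ELBO(x) = E_q[log p(x | topics)] - KL(q(topics | x) || p(topics)),
  # and the training loss is the negative ELBO averaged over the batch.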

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  optimizer = tf.compat.v1.train.AdamOptimizer(params["learning_rate"])

  # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
  # For the first prior_burn_in_steps steps they are fixed, and then trained
  # jointly with the other parameters.
  grads_and_vars = optimizer.compute_gradients(loss)
  grads_and_vars_except_prior = [
      x for x in grads_and_vars if x[1] != logit_concentration]

  def train_op_except_prior():
    return optimizer.apply_gradients(
        grads_and_vars_except_prior,
        global_step=global_step)

  def train_op_all():
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=global_step)

  train_op = tf.cond(
      pred=global_step < params["prior_burn_in_steps"],
      true_fn=train_op_except_prior,
      false_fn=train_op_all)
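  # The predicate is evaluated in-graph on every step, so the prior's
  # parameters start receiving gradient updates automatically once
  # global_step reaches prior_burn_in_steps.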

  # The perplexity is the exponential of the average negative ELBO per word.
  words_per_document = tf.reduce_sum(input_tensor=features, axis=1)
  log_perplexity = -elbo / words_per_document
  tf.compat.v1.summary.scalar(
      "perplexity", tf.exp(tf.reduce_mean(input_tensor=log_perplexity)))
  (log_perplexity_tensor,
   log_perplexity_update) = tf.compat.v1.metrics.mean(log_perplexity)
  perplexity_tensor = tf.exp(log_perplexity_tensor)
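  # Concretely, perplexity = exp(mean(-ELBO_d / N_d)), where N_d is the number
  # of words in document d; tf.compat.v1.metrics.mean keeps a streaming
  # average so the eval metric covers all evaluation batches.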

  # Obtain the topics summary. Implemented as a py_func for simplicity.
  topics = tf.compat.v1.py_func(
      functools.partial(get_topics_strings, vocabulary=params["vocabulary"]),
      [topics_words, concentration],
      tf.string,
      stateful=False)
  tf.compat.v1.summary.text("topics", topics)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.compat.v1.metrics.mean(elbo),
          "log_likelihood": tf.compat.v1.metrics.mean(log_likelihood),
          "kl": tf.compat.v1.metrics.mean(kl),
          "perplexity": (perplexity_tensor, log_perplexity_update),
          "topics": (topics, tf.no_op()),
      },
  )
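
For reference, below is a minimal sketch of how this model_fn might be plugged into an Estimator. The hyperparameter values, the vocabulary array, and train_input_fn are illustrative assumptions, not values taken from this file.

import tensorflow as tf

# Illustrative only: hyperparameter values, `vocabulary`, and `train_input_fn`
# are assumptions, not defaults from this example.
params = {
    "num_topics": 50,
    "layer_sizes": [300, 300, 300],
    "activation": tf.nn.relu,
    "prior_initial_value": 0.7,
    "prior_burn_in_steps": 120000,
    "learning_rate": 3e-4,
    "vocabulary": vocabulary,  # assumed: array of vocabulary strings
}
estimator = tf.estimator.Estimator(
    model_fn,
    params=params,
    config=tf.estimator.RunConfig(model_dir="/tmp/lda"),
)
estimator.train(train_input_fn, max_steps=10000)  # assumed input_fn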