def model_fn()

in example_zoo/tensorflow/probability/vae/trainer/vae.py


def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Arguments:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: A tf.estimator.ModeKeys value signifying whether this is a train,
      eval, or predict call.
    params: A dictionary of hyperparameters.

  Returns:
    A tf.estimator.EstimatorSpec for the given mode.
  """
  del labels, config

  if params["analytic_kl"] and params["mixture_components"] != 1:
    raise NotImplementedError(
        "Using `analytic_kl` is only supported when `mixture_components = 1` "
        "since there's no closed form otherwise.")

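  # Build the variational pieces: an encoder network for q(z|x), a decoder
  # network for p(x|z), and a (possibly mixture-of-Gaussians) prior p(z).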
  encoder = make_encoder(params["activation"],
                         params["latent_size"],
                         params["base_depth"])
  decoder = make_decoder(params["activation"],
                         params["latent_size"],
                         IMAGE_SHAPE,
                         params["base_depth"])
  latent_prior = make_mixture_prior(params["latent_size"],
                                    params["mixture_components"])

  image_tile_summary(
      "input", tf.cast(features, dtype=tf.float32), rows=1, cols=16)

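  # q(z|x): encode the inputs into an approximate posterior, draw
  # `n_samples` latents per input, and decode each into a distribution over
  # images. The samples carry a leading `n_samples` dimension.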
  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(params["n_samples"])
  decoder_likelihood = decoder(approx_posterior_sample)
  image_tile_summary(
      "recon/sample",
      tf.cast(decoder_likelihood.sample()[:3, :16], dtype=tf.float32),
      rows=3,
      cols=16)
  image_tile_summary(
      "recon/mean",
      decoder_likelihood.mean()[:3, :16],
      rows=3,
      cols=16)

  # `distortion` is the negative log-likelihood -log p(x|z), i.e. the
  # reconstruction term of the ELBO.
  distortion = -decoder_likelihood.log_prob(features)
  avg_distortion = tf.reduce_mean(input_tensor=distortion)
  tf.compat.v1.summary.scalar("distortion", avg_distortion)

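  # `rate` is the KL term KL(q(z|x) || p(z)): exact when `analytic_kl` is
  # set (only possible for a single-component prior, per the check above),
  # otherwise a Monte Carlo estimate log q(z|x) - log p(z) evaluated at the
  # posterior samples.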
  if params["analytic_kl"]:
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    rate = (approx_posterior.log_prob(approx_posterior_sample)
            - latent_prior.log_prob(approx_posterior_sample))
  avg_rate = tf.reduce_mean(input_tensor=rate)
  tf.compat.v1.summary.scalar("rate", avg_rate)

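  # Per-sample, per-example ELBO: log p(x|z) - [log q(z|x) - log p(z)],
  # i.e. -(rate + distortion).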
  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(input_tensor=elbo_local)
  loss = -elbo
  tf.compat.v1.summary.scalar("elbo", elbo)

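  # Importance-weighted ELBO (the IWAE bound): a log-mean-exp over the
  # `n_samples` axis, computed as logsumexp(elbo_local) - log(n_samples).
  # This is a tighter lower bound on log p(x) than the plain ELBO.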
  importance_weighted_elbo = tf.reduce_mean(
      input_tensor=tf.reduce_logsumexp(input_tensor=elbo_local, axis=0) -
      tf.math.log(tf.cast(params["n_samples"], dtype=tf.float32)))
  tf.compat.v1.summary.scalar("elbo/importance_weighted",
                              importance_weighted_elbo)

  # Decode samples from the prior for visualization.
  random_image = decoder(latent_prior.sample(16))
  image_tile_summary(
      "random/sample",
      tf.cast(random_image.sample(), dtype=tf.float32),
      rows=4,
      cols=4)
  image_tile_summary("random/mean", random_image.mean(), rows=4, cols=4)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  learning_rate = tf.compat.v1.train.cosine_decay(
      params["learning_rate"], global_step, params["max_steps"])
  tf.compat.v1.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
  train_op = optimizer.minimize(loss, global_step=global_step)

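  # In eval mode, each diagnostic is averaged across all evaluation batches
  # via tf.compat.v1.metrics.mean.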
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo":
              tf.compat.v1.metrics.mean(elbo),
          "elbo/importance_weighted":
              tf.compat.v1.metrics.mean(importance_weighted_elbo),
          "rate":
              tf.compat.v1.metrics.mean(avg_rate),
          "distortion":
              tf.compat.v1.metrics.mean(avg_distortion),
      },
  )
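
For context, here is a minimal sketch of how this model_fn might be wired into an estimator. It assumes the imports and helpers defined elsewhere in vae.py (tf, make_encoder, etc.); the params values below are illustrative assumptions rather than values taken from the original script, and train_input_fn is a hypothetical input_fn yielding image batches, not part of the original file.

# Illustrative hyperparameters; the original trainer builds this dictionary
# from its command-line flags.
params = {
    "activation": "leaky_relu",
    "latent_size": 16,
    "base_depth": 32,
    "mixture_components": 100,
    "analytic_kl": False,
    "n_samples": 16,
    "learning_rate": 0.001,
    "max_steps": 5001,
}

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params=params,
    config=tf.estimator.RunConfig(model_dir="/tmp/vae"),
)

# `train_input_fn` is assumed to yield (features, labels) batches of images
# shaped like IMAGE_SHAPE; the labels are ignored by model_fn.
estimator.train(train_input_fn, max_steps=params["max_steps"])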