in example_zoo/tensorflow/probability/vae/trainer/vae.py [0:0]
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Arguments:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train, evaluate, or predict mode.
    params: Hyperparameters as a dictionary.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A `tf.estimator.EstimatorSpec` instance.
  """
  del labels, config

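  # The analytic KL(q || p) used below is only available when the prior is a
  # single Gaussian; with a mixture prior the KL has no closed form and must
  # be estimated by Monte Carlo instead.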
if params["analytic_kl"] and params["mixture_components"] != 1:
raise NotImplementedError(
"Using `analytic_kl` is only supported when `mixture_components = 1` "
"since there's no closed form otherwise.")
  encoder = make_encoder(params["activation"],
                         params["latent_size"],
                         params["base_depth"])
  decoder = make_decoder(params["activation"],
                         params["latent_size"],
                         IMAGE_SHAPE,
                         params["base_depth"])
  latent_prior = make_mixture_prior(params["latent_size"],
                                    params["mixture_components"])

  image_tile_summary(
      "input", tf.cast(features, dtype=tf.float32), rows=1, cols=16)

  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(params["n_samples"])
  decoder_likelihood = decoder(approx_posterior_sample)
  image_tile_summary(
      "recon/sample",
      tf.cast(decoder_likelihood.sample()[:3, :16], dtype=tf.float32),
      rows=3,
      cols=16)
  image_tile_summary(
      "recon/mean",
      decoder_likelihood.mean()[:3, :16],
      rows=3,
      cols=16)

  # `distortion` is the negative log-likelihood -log p(x | z), i.e. the
  # reconstruction term of the ELBO.
  distortion = -decoder_likelihood.log_prob(features)
  avg_distortion = tf.reduce_mean(input_tensor=distortion)
  tf.compat.v1.summary.scalar("distortion", avg_distortion)

if params["analytic_kl"]:
rate = tfd.kl_divergence(approx_posterior, latent_prior)
else:
rate = (approx_posterior.log_prob(approx_posterior_sample)
- latent_prior.log_prob(approx_posterior_sample))
avg_rate = tf.reduce_mean(input_tensor=rate)
tf.compat.v1.summary.scalar("rate", avg_rate)
  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(input_tensor=elbo_local)
  loss = -elbo
  tf.compat.v1.summary.scalar("elbo", elbo)

  importance_weighted_elbo = tf.reduce_mean(
      input_tensor=tf.reduce_logsumexp(input_tensor=elbo_local, axis=0) -
      tf.math.log(tf.cast(params["n_samples"], dtype=tf.float32)))
  tf.compat.v1.summary.scalar("elbo/importance_weighted",
                              importance_weighted_elbo)

  # Decode samples from the prior for visualization.
  random_image = decoder(latent_prior.sample(16))
  image_tile_summary(
      "random/sample",
      tf.cast(random_image.sample(), dtype=tf.float32),
      rows=4,
      cols=4)
  image_tile_summary("random/mean", random_image.mean(), rows=4, cols=4)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  learning_rate = tf.compat.v1.train.cosine_decay(
      params["learning_rate"], global_step, params["max_steps"])
  tf.compat.v1.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
  train_op = optimizer.minimize(loss, global_step=global_step)

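  # The same spec serves both TRAIN and EVAL modes; the eval metrics stream
  # running means of the scalar diagnostics logged above.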
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.compat.v1.metrics.mean(elbo),
          "elbo/importance_weighted":
              tf.compat.v1.metrics.mean(importance_weighted_elbo),
          "rate": tf.compat.v1.metrics.mean(avg_rate),
          "distortion": tf.compat.v1.metrics.mean(avg_distortion),
      },
  )
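
# Usage sketch (illustrative, not part of the original file). The
# hyperparameter values here are assumptions chosen for demonstration, not
# necessarily the example's defaults:
#
#   params = {"activation": "leaky_relu", "latent_size": 16, "base_depth": 32,
#             "mixture_components": 1, "analytic_kl": True, "n_samples": 16,
#             "learning_rate": 1e-3, "max_steps": 5000}
#   estimator = tf.estimator.Estimator(model_fn=model_fn, params=params,
#                                      model_dir="/tmp/vae")
#   estimator.train(train_input_fn, max_steps=params["max_steps"])
#   estimator.evaluate(eval_input_fn)
#
# `train_input_fn` and `eval_input_fn` are hypothetical input functions that
# yield batches of images shaped like IMAGE_SHAPE.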