example_zoo/tensorflow/probability/vae/trainer/vae.py

# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trains a variational auto-encoder (VAE) on binarized MNIST.

The VAE defines a generative model in which a latent code `Z` is sampled from a
prior `p(Z)`, then used to generate an observation `X` by way of a decoder
`p(X|Z)`. The full reconstruction follows

```none
   X ~ p(X)          # A random image from some dataset.
   Z ~ q(Z | X)      # A random encoding of the original image ("encoder").
Xhat ~ p(Xhat | Z)   # A random reconstruction of the original image
                     #   ("decoder").
```

To fit the VAE, we assume an approximate representation of the posterior in the
form of an encoder `q(Z|X)`. We minimize the KL divergence between `q(Z|X)` and
the true posterior `p(Z|X)`: this is equivalent to maximizing the evidence
lower bound (ELBO),

```none
-log p(x) = -log int dz p(x|z) p(z)
          = -log int dz q(z|x) p(x|z) p(z) / q(z|x)
          <= int dz q(z|x) (-log[ p(x|z) p(z) / q(z|x) ])  # Jensen's Inequality
          =: KL[q(Z|x) || p(x|Z)p(Z)]
          = -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
```

-or-

```none
-log p(x) = KL[q(Z|x) || p(x|Z)p(Z)] - KL[q(Z|x) || p(Z|x)]
          <= KL[q(Z|x) || p(x|Z)p(Z)]                      # Positivity of KL
          = -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
```

The `-E_{Z~q(Z|x)}[log p(x|Z)]` term is an expected reconstruction loss and
`KL[q(Z|x) || p(Z)]` is a kind of distributional regularizer. See
[Kingma and Welling (2014)][1] for more details.

This script supports both a (learned) mixture of Gaussians prior and a fixed
standard normal prior. You can enable the fixed standard normal prior by
setting `mixture_components` to 1. Note that fixing the parameters of the prior
(as opposed to fitting them with the rest of the model) incurs no loss in
generality when using only a single Gaussian. The reasoning for this is
two-fold:

  * On the generative side, the parameters of the prior can simply be absorbed
    into the first linear layer of the generative net. If `z ~ N(mu, Sigma)`
    and the first layer of the generative net is given by `x = Wz + b`, this
    can be rewritten,

      s ~ N(0, I)
      x = Wz + b
        = W (As + mu) + b
        = (WA) s + (W mu + b)

    where Sigma has been decomposed into A A^T = Sigma. In other words, the
    log likelihood of the model (E_{Z~q(Z|x)}[log p(x|Z)]) is independent of
    whether or not we learn mu and Sigma.

  * On the inference side, we can adjust any posterior approximation
    q(z | x) ~ N(mu[q], Sigma[q]), with

      new_mu[p] := 0
      new_Sigma[p] := eye(d)
      new_mu[q] := inv(chol(Sigma[p])) @ (mu[p] - mu[q])
      new_Sigma[q] := inv(Sigma[q]) @ Sigma[p]

    A bit of algebra on the KL divergence term `KL[q(Z|x) || p(Z)]` reveals
    that it is also invariant to the prior parameters as long as Sigma[p] and
    Sigma[q] are invertible.

This script also supports using the analytic KL (KL[q(Z|x) || p(Z)]) with the
`analytic_kl` flag. Using the analytic KL is only supported when
`mixture_components` is set to 1 since otherwise no analytic form is known.
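
For reference, in the single-Gaussian case the analytic KL between the
diagonal-Gaussian posterior and the standard normal prior has the familiar
closed form (the code below does not hand-code this; it simply calls
`tfd.kl_divergence`):

```none
KL[N(mu, diag(sigma^2)) || N(0, I)]
    = 0.5 * sum_d (mu[d]^2 + sigma[d]^2 - 1 - log sigma[d]^2)
```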
Here we also compute a tighter bound, the IWAE [Burda et al. (2015)][2]. This
bound, as well as image summaries, can be seen in TensorBoard. For help using
TensorBoard see https://www.tensorflow.org/guide/summaries_and_tensorboard;
TensorBoard itself can be run with `python -m tensorboard.main
--logdir=MODEL_DIR`.

#### References

[1]: Diederik Kingma and Max Welling. Auto-Encoding Variational Bayes. In
     _International Conference on Learning Representations_, 2014.
     https://arxiv.org/abs/1312.6114
[2]: Yuri Burda, Roger Grosse, Ruslan Salakhutdinov. Importance Weighted
     Autoencoders. In _International Conference on Learning Representations_,
     2015. https://arxiv.org/abs/1509.00519
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

# Dependency imports
from absl import flags
import numpy as np
from six.moves import urllib
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

IMAGE_SHAPE = [28, 28, 1]

flags.DEFINE_string(
    name="job-dir",
    default="/tmp",
    help="AI Platform Training passes this to the training script.")
flags.DEFINE_float(
    "learning_rate", default=0.001, help="Initial learning rate.")
flags.DEFINE_integer(
    "max_steps", default=5001, help="Number of training steps to run.")
flags.DEFINE_integer(
    "latent_size",
    default=16,
    help="Number of dimensions in the latent code (z).")
flags.DEFINE_integer("base_depth", default=32, help="Base depth for layers.")
flags.DEFINE_string(
    "activation",
    default="leaky_relu",
    help="Activation function for all hidden layers.")
flags.DEFINE_integer(
    "batch_size", default=32, help="Batch size.")
flags.DEFINE_integer(
    "n_samples", default=16, help="Number of samples to use in encoding.")
flags.DEFINE_integer(
    "mixture_components",
    default=100,
    help="Number of mixture components to use in the prior. Each component is "
         "a diagonal normal distribution. The parameters of the components are "
         "initialized randomly, and then learned along with the rest of the "
         "parameters. If `analytic_kl` is True, `mixture_components` must be "
         "set to `1`.")
flags.DEFINE_bool(
    "analytic_kl",
    default=False,
    help="Whether or not to use the analytic version of the KL. When set to "
         "False the E_{Z~q(Z|X)}[log p(Z)p(X|Z) - log q(Z|X)] form of the ELBO "
         "will be used. Otherwise the -KL(q(Z|X) || p(Z)) + "
         "E_{Z~q(Z|X)}[log p(X|Z)] form will be used. If analytic_kl is True, "
         "then you must also specify `mixture_components=1`.")
flags.DEFINE_string(
    "data_dir",
    default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "vae/data"),
    help="Directory where data is stored (if using real data).")
flags.DEFINE_string(
    "model_dir",
    default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "vae/"),
    help="Directory to put the model's fit.")
flags.DEFINE_integer(
    "viz_steps", default=500, help="Frequency at which to save visualizations.")
flags.DEFINE_bool(
    "fake_data",
    default=False,
    help="If true, uses fake data instead of MNIST.")
flags.DEFINE_bool(
    "delete_existing",
    default=False,
    help="If true, deletes existing `model_dir` directory.")

FLAGS = flags.FLAGS


def _softplus_inverse(x):
  """Helper which computes the function inverse of `tf.nn.softplus`."""
  return tf.math.log(tf.math.expm1(x))


def make_encoder(activation, latent_size, base_depth):
  """Creates the encoder function.

  Args:
    activation: Activation function in hidden layers.
    latent_size: The dimensionality of the encoding.
    base_depth: The lowest depth for a layer.

  Returns:
    encoder: A `callable` mapping a `Tensor` of images to a
      `tfd.Distribution` instance over encodings.
  """
  conv = functools.partial(
      tf.keras.layers.Conv2D, padding="SAME", activation=activation)

  encoder_net = tf.keras.Sequential([
      conv(base_depth, 5, 1),
      conv(base_depth, 5, 2),
      conv(2 * base_depth, 5, 1),
      conv(2 * base_depth, 5, 2),
      conv(4 * latent_size, 7, padding="VALID"),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(2 * latent_size, activation=None),
  ])

  def encoder(images):
    images = 2 * tf.cast(images, dtype=tf.float32) - 1
    net = encoder_net(images)
    return tfd.MultivariateNormalDiag(
        loc=net[..., :latent_size],
        scale_diag=tf.nn.softplus(net[..., latent_size:] +
                                  _softplus_inverse(1.0)),
        name="code")

  return encoder


def make_decoder(activation, latent_size, output_shape, base_depth):
  """Creates the decoder function.

  Args:
    activation: Activation function in hidden layers.
    latent_size: Dimensionality of the encoding.
    output_shape: The output image shape.
    base_depth: Smallest depth for a layer.

  Returns:
    decoder: A `callable` mapping a `Tensor` of encodings to a
      `tfd.Distribution` instance over images.
  """
  deconv = functools.partial(
      tf.keras.layers.Conv2DTranspose, padding="SAME", activation=activation)
  conv = functools.partial(
      tf.keras.layers.Conv2D, padding="SAME", activation=activation)

  decoder_net = tf.keras.Sequential([
      deconv(2 * base_depth, 7, padding="VALID"),
      deconv(2 * base_depth, 5),
      deconv(2 * base_depth, 5, 2),
      deconv(base_depth, 5),
      deconv(base_depth, 5, 2),
      deconv(base_depth, 5),
      conv(output_shape[-1], 5, activation=None),
  ])

  def decoder(codes):
    original_shape = tf.shape(input=codes)
    # Collapse the sample and batch dimension and convert to rank-4 tensor for
    # use with a convolutional decoder network.
    codes = tf.reshape(codes, (-1, 1, 1, latent_size))
    logits = decoder_net(codes)
    logits = tf.reshape(
        logits, shape=tf.concat([original_shape[:-1], output_shape], axis=0))
    return tfd.Independent(tfd.Bernoulli(logits=logits),
                           reinterpreted_batch_ndims=len(output_shape),
                           name="image")

  return decoder


def make_mixture_prior(latent_size, mixture_components):
  """Creates the mixture of Gaussians prior distribution.

  Args:
    latent_size: The dimensionality of the latent representation.
    mixture_components: Number of elements of the mixture.

  Returns:
    random_prior: A `tfd.Distribution` instance representing the distribution
      over encodings in the absence of any evidence.
  """
  if mixture_components == 1:
    # See the module docstring for why we don't learn the parameters here.
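    # With `loc=0` and `scale_identity_multiplier=1.0`, this is a standard
    # normal N(0, I) prior over the `latent_size`-dimensional code.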
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros([latent_size]),
        scale_identity_multiplier=1.0)

  loc = tf.compat.v1.get_variable(
      name="loc", shape=[mixture_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
      name="raw_scale_diag", shape=[mixture_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(
      name="mixture_logits", shape=[mixture_components])

  return tfd.MixtureSameFamily(
      components_distribution=tfd.MultivariateNormalDiag(
          loc=loc,
          scale_diag=tf.nn.softplus(raw_scale_diag)),
      mixture_distribution=tfd.Categorical(logits=mixture_logits),
      name="prior")


def pack_images(images, rows, cols):
  """Helper utility to make a field of images."""
  shape = tf.shape(input=images)
  width = shape[-3]
  height = shape[-2]
  depth = shape[-1]
  images = tf.reshape(images, (-1, width, height, depth))
  batch = tf.shape(input=images)[0]
  rows = tf.minimum(rows, batch)
  cols = tf.minimum(batch // rows, cols)
  images = images[:rows * cols]
  images = tf.reshape(images, (rows, cols, width, height, depth))
  images = tf.transpose(a=images, perm=[0, 2, 1, 3, 4])
  images = tf.reshape(images, [1, rows * width, cols * height, depth])
  return images


def image_tile_summary(name, tensor, rows=8, cols=8):
  tf.compat.v1.summary.image(
      name, pack_images(tensor, rows, cols), max_outputs=1)


def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an estimator.

  Arguments:
    features: The input features for the estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, config

  if params["analytic_kl"] and params["mixture_components"] != 1:
    raise NotImplementedError(
        "Using `analytic_kl` is only supported when `mixture_components = 1` "
        "since there's no closed form otherwise.")

  encoder = make_encoder(params["activation"],
                         params["latent_size"],
                         params["base_depth"])
  decoder = make_decoder(params["activation"],
                         params["latent_size"],
                         IMAGE_SHAPE,
                         params["base_depth"])
  latent_prior = make_mixture_prior(params["latent_size"],
                                    params["mixture_components"])

  image_tile_summary(
      "input", tf.cast(features, dtype=tf.float32), rows=1, cols=16)

  approx_posterior = encoder(features)
  approx_posterior_sample = approx_posterior.sample(params["n_samples"])
  decoder_likelihood = decoder(approx_posterior_sample)
  image_tile_summary(
      "recon/sample",
      tf.cast(decoder_likelihood.sample()[:3, :16], dtype=tf.float32),
      rows=3,
      cols=16)
  image_tile_summary(
      "recon/mean",
      decoder_likelihood.mean()[:3, :16],
      rows=3,
      cols=16)

  # `distortion` is just the negative log likelihood.
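  # Together with `rate` (the KL term computed below), it yields the negative
  # ELBO: loss = mean(distortion + rate), a rate-distortion style
  # decomposition of the objective.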
  distortion = -decoder_likelihood.log_prob(features)
  avg_distortion = tf.reduce_mean(input_tensor=distortion)
  tf.compat.v1.summary.scalar("distortion", avg_distortion)

  if params["analytic_kl"]:
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    rate = (approx_posterior.log_prob(approx_posterior_sample) -
            latent_prior.log_prob(approx_posterior_sample))
  avg_rate = tf.reduce_mean(input_tensor=rate)
  tf.compat.v1.summary.scalar("rate", avg_rate)

  elbo_local = -(rate + distortion)

  elbo = tf.reduce_mean(input_tensor=elbo_local)
  loss = -elbo
  tf.compat.v1.summary.scalar("elbo", elbo)

  importance_weighted_elbo = tf.reduce_mean(
      input_tensor=tf.reduce_logsumexp(input_tensor=elbo_local, axis=0) -
      tf.math.log(tf.cast(params["n_samples"], dtype=tf.float32)))
  tf.compat.v1.summary.scalar("elbo/importance_weighted",
                              importance_weighted_elbo)

  # Decode samples from the prior for visualization.
  random_image = decoder(latent_prior.sample(16))
  image_tile_summary(
      "random/sample",
      tf.cast(random_image.sample(), dtype=tf.float32),
      rows=4,
      cols=4)
  image_tile_summary("random/mean", random_image.mean(), rows=4, cols=4)

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  learning_rate = tf.compat.v1.train.cosine_decay(
      params["learning_rate"], global_step, params["max_steps"])
  tf.compat.v1.summary.scalar("learning_rate", learning_rate)
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
  train_op = optimizer.minimize(loss, global_step=global_step)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.compat.v1.metrics.mean(elbo),
          "elbo/importance_weighted": tf.compat.v1.metrics.mean(
              importance_weighted_elbo),
          "rate": tf.compat.v1.metrics.mean(avg_rate),
          "distortion": tf.compat.v1.metrics.mean(avg_distortion),
      },
  )


ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
FILE_TEMPLATE = "binarized_mnist_{split}.amat"


def download(directory, filename):
  """Downloads a file."""
  filepath = os.path.join(directory, filename)
  if tf.io.gfile.exists(filepath):
    return filepath
  if not tf.io.gfile.exists(directory):
    tf.io.gfile.makedirs(directory)
  url = os.path.join(ROOT_PATH, filename)
  print("Downloading %s to %s" % (url, filepath))
  urllib.request.urlretrieve(url, filepath)
  return filepath


def static_mnist_dataset(directory, split_name):
  """Returns binary static MNIST tf.data.Dataset."""
  amat_file = download(directory, FILE_TEMPLATE.format(split=split_name))
  dataset = tf.data.TextLineDataset(amat_file)
  str_to_arr = lambda string: np.array([c == b"1" for c in string.split()])

  def _parser(s):
    booltensor = tf.compat.v1.py_func(str_to_arr, [s], tf.bool)
    reshaped = tf.reshape(booltensor, [28, 28, 1])
    return tf.cast(reshaped, dtype=tf.float32), tf.constant(0, tf.int32)

  return dataset.map(_parser)


def build_fake_input_fns(batch_size):
  """Builds fake MNIST-style data for unit testing."""
  random_sample = np.random.rand(batch_size, *IMAGE_SHAPE).astype("float32")

  def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices(
        random_sample).map(lambda row: (row, 0)).batch(batch_size).repeat()
    return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

  def eval_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices(
        random_sample).map(lambda row: (row, 0)).batch(batch_size)
    return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

  return train_input_fn, eval_input_fn


def build_input_fns(data_dir, batch_size):
  """Builds an Iterator switching between train and heldout data."""

  # Build an iterator over training batches.
  def train_input_fn():
    dataset = static_mnist_dataset(data_dir, "train")
    dataset = dataset.shuffle(50000).repeat().batch(batch_size)
    return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

  # Build an iterator over the heldout set.
  def eval_input_fn():
    eval_dataset = static_mnist_dataset(data_dir, "valid")
    eval_dataset = eval_dataset.batch(batch_size)
    return tf.compat.v1.data.make_one_shot_iterator(eval_dataset).get_next()

  return train_input_fn, eval_input_fn


def main(argv):
  del argv  # unused

  params = FLAGS.flag_values_dict()
  params["activation"] = getattr(tf.nn, params["activation"])

  if FLAGS.delete_existing and tf.io.gfile.exists(FLAGS.model_dir):
    tf.compat.v1.logging.warn("Deleting old log directory at {}".format(
        FLAGS.model_dir))
    tf.io.gfile.rmtree(FLAGS.model_dir)
  tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.fake_data:
    train_input_fn, eval_input_fn = build_fake_input_fns(FLAGS.batch_size)
  else:
    train_input_fn, eval_input_fn = build_input_fns(FLAGS.data_dir,
                                                    FLAGS.batch_size)

  estimator = tf.estimator.Estimator(
      model_fn,
      params=params,
      config=tf.estimator.RunConfig(
          model_dir=FLAGS.model_dir,
          save_checkpoints_steps=FLAGS.viz_steps,
      ),
  )

  for _ in range(FLAGS.max_steps // FLAGS.viz_steps):
    estimator.train(train_input_fn, steps=FLAGS.viz_steps)
    eval_results = estimator.evaluate(eval_input_fn)
    print("Evaluation_results:\n\t%s\n" % eval_results)


if __name__ == "__main__":
  tf.compat.v1.app.run()
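

# Example invocation (illustrative only; the flags used here are defined
# above, and the binarized MNIST data is downloaded automatically on first
# use):
#
#   python vae.py --model_dir=/tmp/vae --mixture_components=1 --analytic_kl
#
# Summaries (ELBO, IWAE bound, reconstructions, prior samples) can then be
# viewed with `python -m tensorboard.main --logdir=/tmp/vae`.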