examples/vae/vae.py (122 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import numpy as np
import torch
import torch.nn as nn
import visdom
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam
from utils.mnist_cached import MNISTCached as MNIST
from utils.mnist_cached import setup_data_loaders
from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples
# define the PyTorch module that parameterizes the
# diagonal Gaussian distribution q(z|x)
class Encoder(nn.Module):
    """Amortized encoder producing the parameters of a diagonal Normal q(z|x).

    Args:
        z_dim: dimensionality of the latent code.
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # shared hidden layer over the flattened 784-pixel input
        self.fc1 = nn.Linear(784, hidden_dim)
        # two heads on the hidden units: location and (pre-exp) scale
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # smooth non-linearity applied to the hidden layer
        self.softplus = nn.Softplus()

    def forward(self, x):
        """Return ``(z_loc, z_scale)``, each of shape ``(batch, z_dim)``.

        ``x`` is flattened to ``(-1, 784)`` first, so any tensor whose element
        count is a multiple of 784 is accepted; ``batch`` is that count / 784.
        """
        pixels = x.reshape(-1, 784)
        features = self.softplus(self.fc1(pixels))
        # exponentiating the second head keeps the scale strictly positive
        return self.fc21(features), torch.exp(self.fc22(features))
# define the PyTorch module that parameterizes the
# observation likelihood p(x|z)
class Decoder(nn.Module):
    """Decoder mapping latent codes to per-pixel Bernoulli means.

    Args:
        z_dim: dimensionality of the latent code.
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # latent code -> hidden units -> 784 pixel logits
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        # smooth non-linearity applied to the hidden layer
        self.softplus = nn.Softplus()

    def forward(self, z):
        """Map ``z`` of shape ``(..., z_dim)`` to means of shape ``(..., 784)``.

        The sigmoid squashes each output into the Bernoulli parameter range
        [0, 1].
        """
        hidden_units = self.softplus(self.fc1(z))
        logits = self.fc21(hidden_units)
        return torch.sigmoid(logits)
# define a PyTorch module for the VAE
class VAE(nn.Module):
    """Variational autoencoder expressed as a Pyro model/guide pair.

    ``model`` is the generative process p(x|z)p(z): a standard Normal prior
    over ``z_dim`` latents and a Bernoulli likelihood over 784 pixels.
    ``guide`` is the amortized variational distribution q(z|x).

    Args:
        z_dim: latent dimensionality (default 50).
        hidden_dim: hidden-layer width of encoder and decoder (default 400).
        use_cuda: if True, move all encoder/decoder parameters to the GPU.
    """

    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, x):
        """Generative model p(x|z)p(z); returns the decoded Bernoulli means."""
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        with pyro.plate("data", batch_size):
            # prior hyperparameters, created with x's dtype and device
            z_loc = x.new_zeros((batch_size, self.z_dim))
            z_scale = x.new_ones((batch_size, self.z_dim))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # call the module (not .forward) so nn.Module hooks are honoured
            loc_img = self.decoder(z)
            # score against actual images. validate_args=False because the
            # observed pixels are presumably intensities in [0, 1] rather than
            # exact {0, 1} values (TODO confirm in the data loader); Bernoulli
            # support validation would otherwise reject them.
            pyro.sample("obs",
                        dist.Bernoulli(loc_img, validate_args=False).to_event(1),
                        obs=x.reshape(-1, 784))
            # return the loc so we can visualize it later
            return loc_img

    def guide(self, x):
        """Variational distribution q(z|x) produced by the encoder network."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # call the module (not .forward) so nn.Module hooks are honoured
            z_loc, z_scale = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    def reconstruct_img(self, x):
        """Encode ``x``, draw one latent sample, and return the decoded means.

        No sampling is done in image space; the Bernoulli means are returned.
        """
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        return self.decoder(z)
def _train_epoch(svi, train_loader, use_cuda):
    """Take one SVI step per training mini-batch; return loss per training example."""
    epoch_loss = 0.
    for x, _ in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)
    return epoch_loss / len(train_loader.dataset)


def _evaluate(svi, test_loader, use_cuda, on_first_batch=None):
    """Accumulate ``svi.evaluate_loss`` over the test set; return loss per test example.

    If ``on_first_batch`` is given, it is called with the first (device-placed)
    mini-batch right after that batch's loss has been computed.
    """
    test_loss = 0.
    for i, (x, _) in enumerate(test_loader):
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x)
        if i == 0 and on_first_batch is not None:
            on_first_batch(x)
    return test_loss / len(test_loader.dataset)


def _show_reconstructions(vae, vis, x, num_images=3):
    """Plot VAE samples, then random test images from ``x`` next to their reconstructions."""
    plot_vae_samples(vae, vis)
    reco_indices = np.random.randint(0, x.shape[0], num_images)
    for index in reco_indices:
        test_img = x[index, :]
        # reconstructions are only displayed, so skip building an autograd graph
        with torch.no_grad():
            reco_img = vae.reconstruct_img(test_img)
        vis.image(test_img.reshape(28, 28).detach().cpu().numpy(),
                  opts={'caption': 'test image'})
        vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(),
                  opts={'caption': 'reconstructed image'})


def main(args):
    """Train the MNIST VAE with SVI and return the trained ``VAE``.

    Reads ``args.cuda``, ``args.learning_rate``, ``args.jit``,
    ``args.visdom_flag``, ``args.num_epochs``, ``args.test_frequency`` and
    ``args.tsne_iter``. Prints the average training loss every epoch and the
    average test loss every ``test_frequency`` epochs.
    """
    # clear param store
    pyro.clear_param_store()
    # setup MNIST data loaders
    train_loader, test_loader = setup_data_loaders(MNIST, use_cuda=args.cuda, batch_size=256)
    # setup the VAE, optimizer and inference algorithm
    vae = VAE(use_cuda=args.cuda)
    optimizer = Adam({"lr": args.learning_rate})
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # visdom is only contacted when plotting was requested; `vis` is always bound
    vis = visdom.Visdom() if args.visdom_flag else None
    on_first_test_batch = None
    if vis is not None:
        def on_first_test_batch(x):
            _show_reconstructions(vae, vis, x)

    train_elbo = []
    test_elbo = []
    # training loop
    for epoch in range(args.num_epochs):
        total_epoch_loss_train = _train_epoch(svi, train_loader, args.cuda)
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % args.test_frequency == 0:
            total_epoch_loss_test = _evaluate(svi, test_loader, args.cuda, on_first_test_batch)
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

        if epoch == args.tsne_iter:
            mnist_test_tsne(vae=vae, test_loader=test_loader)
            plot_llk(np.array(train_elbo), np.array(test_elbo))

    return vae
if __name__ == '__main__':
    # the example is pinned to a specific Pyro release
    assert pyro.__version__.startswith('1.4.0')
    # build the command-line interface
    cli_parser = argparse.ArgumentParser(description="parse args")
    cli_parser.add_argument('-n', '--num-epochs', type=int, default=101,
                            help='number of training epochs')
    cli_parser.add_argument('-tf', '--test-frequency', type=int, default=5,
                            help='how often we evaluate the test set')
    cli_parser.add_argument('-lr', '--learning-rate', type=float, default=1.0e-3,
                            help='learning rate')
    cli_parser.add_argument('--cuda', action='store_true', default=False,
                            help='whether to use cuda')
    cli_parser.add_argument('--jit', action='store_true', default=False,
                            help='whether to use PyTorch jit')
    cli_parser.add_argument('-visdom', '--visdom_flag', action="store_true",
                            help='Whether plotting in visdom is desired')
    cli_parser.add_argument('-i-tsne', '--tsne_iter', type=int, default=100,
                            help='epoch when tsne visualization runs')
    args = cli_parser.parse_args()
    # train and keep the fitted VAE at module scope
    model = main(args)