# examples/vae/ss_vae_M2.py
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import torch
import torch.nn as nn
from visdom import Visdom
import pyro
import pyro.distributions as dist
from pyro.contrib.examples.util import print_and_log
from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam
from utils.custom_mlp import MLP, Exp
from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders
from utils.vae_plots import mnist_test_tsne_ssvae, plot_conditional_samples_ssvae
class SSVAE(nn.Module):
"""
This class encapsulates the parameters (neural networks) and models & guides needed to train a
semi-supervised variational auto-encoder on the MNIST image dataset
:param output_size: size of the tensor representing the class label (10 for MNIST since
we represent the class labels as a one-hot vector with 10 components)
:param input_size: size of the tensor representing the image (28*28 = 784 for our MNIST dataset
since we flatten the images and scale the pixels to be in [0,1])
:param z_dim: size of the tensor representing the latent random variable z
(handwriting style for our MNIST dataset)
:param hidden_layers: a tuple (or list) of MLP layers to be used in the neural networks
representing the parameters of the distributions in our model
:param use_cuda: use GPUs for faster training
:param aux_loss_multiplier: the multiplier to use with the auxiliary loss
"""
def __init__(self, output_size=10, input_size=784, z_dim=50, hidden_layers=(500,),
config_enum=None, use_cuda=False, aux_loss_multiplier=None):
super().__init__()
# initialize the class with all arguments provided to the constructor
self.output_size = output_size
self.input_size = input_size
self.z_dim = z_dim
self.hidden_layers = hidden_layers
self.allow_broadcast = config_enum == 'parallel'
self.use_cuda = use_cuda
self.aux_loss_multiplier = aux_loss_multiplier
# define and instantiate the neural networks representing
        # the parameters of various distributions in the model
self.setup_networks()
def setup_networks(self):
z_dim = self.z_dim
hidden_sizes = self.hidden_layers
# define the neural networks used later in the model and the guide.
# these networks are MLPs (multi-layered perceptrons or simple feed-forward networks)
# where the provided activation parameter is used on every linear layer except
# for the output layer where we use the provided output_activation parameter
self.encoder_y = MLP([self.input_size] + hidden_sizes + [self.output_size],
activation=nn.Softplus,
output_activation=nn.Softmax,
allow_broadcast=self.allow_broadcast,
use_cuda=self.use_cuda)
# a split in the final layer's size is used for multiple outputs
# and potentially applying separate activation functions on them
# e.g. in this network the final output is of size [z_dim,z_dim]
# to produce loc and scale, and apply different activations [None,Exp] on them
self.encoder_z = MLP([self.input_size + self.output_size] +
hidden_sizes + [[z_dim, z_dim]],
activation=nn.Softplus,
output_activation=[None, Exp],
allow_broadcast=self.allow_broadcast,
use_cuda=self.use_cuda)
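        # for example, with z_dim=50 and a batch of 200 (image, one-hot label) pairs,
        # encoder_z.forward([xs, ys]) is expected to return a pair (loc, scale),
        # each of shape (200, 50); the Exp output activation keeps scale positive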
self.decoder = MLP([z_dim + self.output_size] +
hidden_sizes + [self.input_size],
activation=nn.Softplus,
output_activation=nn.Sigmoid,
allow_broadcast=self.allow_broadcast,
use_cuda=self.use_cuda)
# using GPUs for faster training of the networks
if self.use_cuda:
self.cuda()
def model(self, xs, ys=None):
"""
The model corresponds to the following generative process:
p(z) = normal(0,I) # handwriting style (latent)
        p(y) = categorical(I/10.)     # which digit (semi-supervised)
p(x|y,z) = bernoulli(loc(y,z)) # an image
loc is given by a neural network `decoder`
:param xs: a batch of scaled vectors of pixels from an image
:param ys: (optional) a batch of the class labels i.e.
the digit corresponding to the image(s)
        :return: the tensor `loc` of per-pixel Bernoulli means (returned so the images can be visualized)
"""
# register this pytorch module and all of its sub-modules with pyro
pyro.module("ss_vae", self)
batch_size = xs.size(0)
options = dict(dtype=xs.dtype, device=xs.device)
with pyro.plate("data"):
# sample the handwriting style from the constant prior distribution
prior_loc = torch.zeros(batch_size, self.z_dim, **options)
prior_scale = torch.ones(batch_size, self.z_dim, **options)
zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))
            # if the label y (which digit to write) is unsupervised, sample it from the
            # constant prior; otherwise, observe the given value (i.e. score it against the constant prior)
alpha_prior = torch.ones(batch_size, self.output_size, **options) / (1.0 * self.output_size)
ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)
# finally, score the image (x) using the handwriting style (z) and
# the class label y (which digit to write) against the
# parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
# where `decoder` is a neural network
loc = self.decoder.forward([zs, ys])
pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
# return the loc so we can visualize it later
return loc
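    # A sketch of how the decoder alone can be used for generation (illustrative
    # only; the style vector below is made up):
    #
    #     zs = torch.randn(1, 50)                 # hypothetical handwriting style
    #     ys = torch.eye(10)[3].unsqueeze(0)      # the digit "3" as a one-hot row
    #     loc = ss_vae.decoder.forward([zs, ys])  # per-pixel Bernoulli means, shape (1, 784)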
def guide(self, xs, ys=None):
"""
The guide corresponds to the following:
q(y|x) = categorical(alpha(x)) # infer digit from an image
q(z|x,y) = normal(loc(x,y),scale(x,y)) # infer handwriting style from an image and the digit
loc, scale are given by a neural network `encoder_z`
alpha is given by a neural network `encoder_y`
:param xs: a batch of scaled vectors of pixels from an image
:param ys: (optional) a batch of the class labels i.e.
the digit corresponding to the image(s)
:return: None
"""
# inform Pyro that the variables in the batch of xs, ys are conditionally independent
with pyro.plate("data"):
# if the class label (the digit) is not supervised, sample
# (and score) the digit with the variational distribution
# q(y|x) = categorical(alpha(x))
if ys is None:
alpha = self.encoder_y.forward(xs)
ys = pyro.sample("y", dist.OneHotCategorical(alpha))
# sample (and score) the latent handwriting-style with the variational
# distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
loc, scale = self.encoder_z.forward([xs, ys])
pyro.sample("z", dist.Normal(loc, scale).to_event(1))
def classifier(self, xs):
"""
classify an image (or a batch of images)
:param xs: a batch of scaled vectors of pixels from an image
:return: a batch of the corresponding class labels (as one-hots)
"""
# use the trained model q(y|x) = categorical(alpha(x))
# compute all class probabilities for the image(s)
alpha = self.encoder_y.forward(xs)
# get the index (digit) that corresponds to
# the maximum predicted class probability
res, ind = torch.topk(alpha, 1)
# convert the digit(s) to one-hot tensor(s)
ys = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
return ys
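    # A worked example of the one-hot conversion above (made-up numbers):
    # if alpha = [[0.1, 0.7, 0.2]], topk gives ind = [[1]] and scatter_ yields
    # ys = [[0., 1., 0.]], i.e. the argmax encoded as a one-hot vector.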
def model_classify(self, xs, ys=None):
"""
        this model is used to add an auxiliary (supervised) loss as described in
        Kingma et al., "Semi-Supervised Learning with Deep Generative Models"
"""
# register all pytorch (sub)modules with pyro
pyro.module("ss_vae", self)
# inform Pyro that the variables in the batch of xs, ys are conditionally independent
with pyro.plate("data"):
# this here is the extra term to yield an auxiliary loss that we do gradient descent on
if ys is not None:
alpha = self.encoder_y.forward(xs)
with pyro.poutine.scale(scale=self.aux_loss_multiplier):
pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
def guide_classify(self, xs, ys=None):
"""
dummy guide function to accompany model_classify in inference
"""
pass
def run_inference_for_epoch(data_loaders, losses, periodic_interval_batches):
"""
    runs the inference algorithm for an epoch;
    returns the values of all losses separately on the supervised and unsupervised parts
"""
num_losses = len(losses)
# compute number of batches for an epoch
sup_batches = len(data_loaders["sup"])
unsup_batches = len(data_loaders["unsup"])
batches_per_epoch = sup_batches + unsup_batches
# initialize variables to store loss values
epoch_losses_sup = [0.] * num_losses
epoch_losses_unsup = [0.] * num_losses
# setup the iterators for training data loaders
sup_iter = iter(data_loaders["sup"])
unsup_iter = iter(data_loaders["unsup"])
# count the number of supervised batches seen in this epoch
ctr_sup = 0
for i in range(batches_per_epoch):
# whether this batch is supervised or not
is_supervised = (i % periodic_interval_batches == 1) and ctr_sup < sup_batches
# extract the corresponding batch
if is_supervised:
(xs, ys) = next(sup_iter)
ctr_sup += 1
else:
(xs, ys) = next(unsup_iter)
# run the inference for each loss with supervised or un-supervised
# data as arguments
for loss_id in range(num_losses):
if is_supervised:
new_loss = losses[loss_id].step(xs, ys)
epoch_losses_sup[loss_id] += new_loss
else:
new_loss = losses[loss_id].step(xs)
epoch_losses_unsup[loss_id] += new_loss
# return the values of all losses
return epoch_losses_sup, epoch_losses_unsup
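# A worked example of the interleaving above, using the example-run settings:
# with 3000 supervised and 47000 unsupervised examples and a batch size of 200,
# there are 15 supervised and 235 unsupervised batches per epoch and
# periodic_interval_batches = 16, so roughly every 16th batch is supervised
# until the 15 supervised batches have been used up.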
def get_accuracy(data_loader, classifier_fn, batch_size):
"""
compute the accuracy over the supervised training set or the testing set
"""
predictions, actuals = [], []
# use the appropriate data loader
for (xs, ys) in data_loader:
# use classification function to compute all predictions for each batch
predictions.append(classifier_fn(xs))
actuals.append(ys)
# compute the number of accurate predictions
accurate_preds = 0
for pred, act in zip(predictions, actuals):
for i in range(pred.size(0)):
v = torch.sum(pred[i] == act[i])
accurate_preds += (v.item() == 10)
# calculate the accuracy between 0 and 1
accuracy = (accurate_preds * 1.0) / (len(predictions) * batch_size)
return accuracy
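# A small worked example of the accuracy check above: for 10-dimensional one-hot
# vectors, pred[i] == act[i] compares elementwise, so its sum equals 10 exactly
# when prediction and label agree in every position, i.e. when the predicted
# digit matches the true digit.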
def visualize(ss_vae, viz, test_loader):
if viz:
plot_conditional_samples_ssvae(ss_vae, viz)
mnist_test_tsne_ssvae(ssvae=ss_vae, test_loader=test_loader)
def main(args):
"""
run inference for SS-VAE
:param args: arguments for SS-VAE
:return: None
"""
if args.seed is not None:
pyro.set_rng_seed(args.seed)
viz = None
if args.visualize:
viz = Visdom()
mkdir_p("./vae_results")
# batch_size: number of images (and labels) to be considered in a batch
ss_vae = SSVAE(z_dim=args.z_dim,
hidden_layers=args.hidden_layers,
use_cuda=args.cuda,
config_enum=args.enum_discrete,
aux_loss_multiplier=args.aux_loss_multiplier)
# setup the optimizer
adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
optimizer = Adam(adam_params)
    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label of the discrete categorical distribution sampled in the guide
guide = config_enumerate(ss_vae.guide, args.enum_discrete, expand=True)
elbo = (JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO)(max_plate_nesting=1)
loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)
# build a list of all losses considered
losses = [loss_basic]
    # aux_loss: whether to use the auxiliary loss from the NIPS 2014 paper (Kingma et al.)
if args.aux_loss:
elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo)
losses.append(loss_aux)
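    # each SVI object bundles a (model, guide, optimizer, loss) tuple; calling
    # losses[k].step(xs) (or step(xs, ys) on supervised batches) computes that loss,
    # backpropagates through the networks registered via pyro.module, takes an
    # Adam step on their parameters and returns the loss value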
try:
# setup the logger if a filename is provided
logger = open(args.logfile, "w") if args.logfile else None
data_loaders = setup_data_loaders(MNISTCached, args.cuda, args.batch_size, sup_num=args.sup_num)
        # how often a supervised batch is encountered during inference
        # e.g. if sup_num is 3000, every 16th = int(50000/3000) batch is supervised
        # until we have traversed through all the supervised batches
periodic_interval_batches = int(MNISTCached.train_data_size / (1.0 * args.sup_num))
# number of unsupervised examples
unsup_num = MNISTCached.train_data_size - args.sup_num
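        # e.g. with the defaults (sup_num=3000, 50000 training images):
        # periodic_interval_batches = 16 and unsup_num = 47000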
# initializing local variables to maintain the best validation accuracy
# seen across epochs over the supervised training set
# and the corresponding testing set and the state of the networks
best_valid_acc, corresponding_test_acc = 0.0, 0.0
# run inference for a certain number of epochs
for i in range(0, args.num_epochs):
# get the losses for an epoch
epoch_losses_sup, epoch_losses_unsup = \
run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)
# compute average epoch losses i.e. losses per example
avg_epoch_losses_sup = map(lambda v: v / args.sup_num, epoch_losses_sup)
avg_epoch_losses_unsup = map(lambda v: v / unsup_num, epoch_losses_unsup)
# store the loss and validation/testing accuracies in the logfile
str_loss_sup = " ".join(map(str, avg_epoch_losses_sup))
str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup))
str_print = "{} epoch: avg losses {}".format(i, "{} {}".format(str_loss_sup, str_loss_unsup))
validation_accuracy = get_accuracy(data_loaders["valid"], ss_vae.classifier, args.batch_size)
str_print += " validation accuracy {}".format(validation_accuracy)
            # this test accuracy is only for logging; it is not used
            # to make any decisions during training
test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size)
str_print += " test accuracy {}".format(test_accuracy)
# update the best validation accuracy and the corresponding
# testing accuracy and the state of the parent module (including the networks)
if best_valid_acc < validation_accuracy:
best_valid_acc = validation_accuracy
corresponding_test_acc = test_accuracy
print_and_log(logger, str_print)
final_test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size)
print_and_log(logger, "best validation accuracy {} corresponding testing accuracy {} "
"last testing accuracy {}".format(best_valid_acc, corresponding_test_acc, final_test_accuracy))
# visualize the conditional samples
visualize(ss_vae, viz, data_loaders["test"])
finally:
# close the logger file object if we opened it earlier
if args.logfile:
logger.close()
EXAMPLE_RUN = "example run: python ss_vae_M2.py --seed 0 --cuda -n 2 --aux-loss -alm 46 -enum parallel " \
"-sup 3000 -zd 50 -hl 500 -lr 0.00042 -b1 0.95 -bs 200 -log ./tmp.log"
if __name__ == "__main__":
assert pyro.__version__.startswith('1.4.0')
parser = argparse.ArgumentParser(description="SS-VAE\n{}".format(EXAMPLE_RUN))
parser.add_argument('--cuda', action='store_true',
help="use GPU(s) to speed up training")
parser.add_argument('--jit', action='store_true',
help="use PyTorch jit to speed up training")
parser.add_argument('-n', '--num-epochs', default=50, type=int,
help="number of epochs to run")
parser.add_argument('--aux-loss', action="store_true",
help="whether to use the auxiliary loss from NIPS 14 paper "
"(Kingma et al). It is not used by default ")
parser.add_argument('-alm', '--aux-loss-multiplier', default=46, type=float,
help="the multiplier to use with the auxiliary loss")
parser.add_argument('-enum', '--enum-discrete', default="parallel",
help="parallel, sequential or none. uses parallel enumeration by default")
parser.add_argument('-sup', '--sup-num', default=3000,
type=float, help="supervised amount of the data i.e. "
"how many of the images have supervised labels")
parser.add_argument('-zd', '--z-dim', default=50, type=int,
help="size of the tensor representing the latent variable z "
"variable (handwriting style for our MNIST dataset)")
parser.add_argument('-hl', '--hidden-layers', nargs='+', default=[500], type=int,
help="a tuple (or list) of MLP layers to be used in the neural networks "
"representing the parameters of the distributions in our model")
parser.add_argument('-lr', '--learning-rate', default=0.00042, type=float,
help="learning rate for Adam optimizer")
parser.add_argument('-b1', '--beta-1', default=0.9, type=float,
help="beta-1 parameter for Adam optimizer")
parser.add_argument('-bs', '--batch-size', default=200, type=int,
help="number of images (and labels) to be considered in a batch")
parser.add_argument('-log', '--logfile', default="./tmp.log", type=str,
help="filename for logging the outputs")
parser.add_argument('--seed', default=None, type=int,
help="seed for controlling randomness in this example")
parser.add_argument('--visualize', action="store_true",
help="use a visdom server to visualize the embeddings")
args = parser.parse_args()
# some assertions to make sure that batching math assumptions are met
assert args.sup_num % args.batch_size == 0, "assuming simplicity of batching math"
assert MNISTCached.validation_size % args.batch_size == 0, \
"batch size should divide the number of validation examples"
assert MNISTCached.train_data_size % args.batch_size == 0, \
"batch size doesn't divide total number of training data examples"
assert MNISTCached.test_size % args.batch_size == 0, "batch size should divide the number of test examples"
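    # e.g. with the example run above, 3000 % 200 == 0 and 50000 % 200 == 0, so the
    # supervised/unsupervised split produces only full-sized batches, and because the
    # validation and test sizes are also divisible by 200, the
    # len(predictions) * batch_size denominator in get_accuracy is exact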
main(args)