in examples/vae/vae.py [0:0]
def _train_epoch(svi, train_loader, use_cuda):
    """Run one SVI pass over ``train_loader`` and return the summed loss.

    The returned value is the sum of ``svi.step`` results over all
    mini-batches; the caller is responsible for normalizing it.
    """
    epoch_loss = 0.
    for x, _ in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # take an ELBO gradient step and accumulate the loss
        epoch_loss += svi.step(x)
    return epoch_loss


def _visualize_reconstructions(vae, vis, x, num_images=3):
    """Send model samples and ``num_images`` test/reconstruction pairs to visdom.

    Images are drawn at random (with replacement) from the mini-batch ``x``
    and reshaped to 28x28 for display.
    """
    plot_vae_samples(vae, vis)
    reco_indices = np.random.randint(0, x.shape[0], num_images)
    for index in reco_indices:
        test_img = x[index, :]
        reco_img = vae.reconstruct_img(test_img)
        vis.image(test_img.reshape(28, 28).detach().cpu().numpy(),
                  opts={'caption': 'test image'})
        vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(),
                  opts={'caption': 'reconstructed image'})


def _evaluate(svi, vae, test_loader, use_cuda, vis=None):
    """Return the summed ``svi.evaluate_loss`` over ``test_loader``.

    When ``vis`` is given, the first mini-batch is also used to visualize
    how well the VAE reconstructs test images.
    """
    test_loss = 0.
    for i, (x, _) in enumerate(test_loader):
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate (no gradient step) and accumulate loss
        test_loss += svi.evaluate_loss(x)
        # only the first mini-batch is used for visualization
        if i == 0 and vis is not None:
            _visualize_reconstructions(vae, vis, x)
    return test_loss


def main(args):
    """Train a VAE on MNIST with SVI, periodically evaluating on the test set.

    Attributes read from ``args``: ``cuda``, ``learning_rate``, ``jit``,
    ``visdom_flag``, ``num_epochs``, ``test_frequency``, ``tsne_iter`` and,
    optionally, ``batch_size`` (defaults to 256 when absent).

    Returns the trained ``VAE`` instance.
    """
    # clear param store so repeated runs start from fresh parameters
    pyro.clear_param_store()

    # setup MNIST data loaders
    batch_size = getattr(args, "batch_size", 256)
    train_loader, test_loader = setup_data_loaders(MNIST, use_cuda=args.cuda,
                                                   batch_size=batch_size)

    # setup the VAE, optimizer and inference algorithm
    vae = VAE(use_cuda=args.cuda)
    optimizer = Adam({"lr": args.learning_rate})
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # setup visdom for visualization; None disables all plotting
    vis = visdom.Visdom() if args.visdom_flag else None

    train_elbo = []
    test_elbo = []
    for epoch in range(args.num_epochs):
        # report training diagnostics, normalized by the number of examples
        epoch_loss = _train_epoch(svi, train_loader, args.cuda)
        total_epoch_loss_train = epoch_loss / len(train_loader.dataset)
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % args.test_frequency == 0:
            # report test diagnostics, normalized by the number of examples
            test_loss = _evaluate(svi, vae, test_loader, args.cuda, vis=vis)
            total_epoch_loss_test = test_loss / len(test_loader.dataset)
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

        if epoch == args.tsne_iter:
            mnist_test_tsne(vae=vae, test_loader=test_loader)
            plot_llk(np.array(train_elbo), np.array(test_elbo))

    return vae