in examples/vae/ss_vae_M2.py [0:0]
def main(args):
    """
    run inference for SS-VAE

    Builds the SS-VAE model, an Adam optimizer and the SVI loss(es), then trains
    for ``args.num_epochs`` epochs. After each epoch it logs the per-example
    average supervised/unsupervised losses and the accuracy of
    ``ss_vae.classifier`` on the validation and test loaders. At the end it
    reports the best validation accuracy (and the test accuracy observed in that
    same epoch), the final test accuracy, and calls ``visualize`` on the test
    loader.

    :param args: arguments for SS-VAE (an argparse-style namespace; this function
        reads seed, visualize, z_dim, hidden_layers, cuda, enum_discrete,
        aux_loss_multiplier, learning_rate, beta_1, jit, aux_loss, logfile,
        batch_size, sup_num and num_epochs)
    :return: None
    """
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    viz = None
    if args.visualize:
        viz = Visdom()
        mkdir_p("./vae_results")

    ss_vae = SSVAE(z_dim=args.z_dim,
                   hidden_layers=args.hidden_layers,
                   use_cuda=args.cuda,
                   config_enum=args.enum_discrete,
                   aux_loss_multiplier=args.aux_loss_multiplier)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
    optimizer = Adam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    guide = config_enumerate(ss_vae.guide, args.enum_discrete, expand=True)
    elbo = (JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO)(max_plate_nesting=1)
    loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    if args.aux_loss:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo)
        losses.append(loss_aux)

    # Bind the logger before entering ``try`` so that the ``finally`` clause can
    # always inspect it, even when opening the log file itself fails.
    logger = None
    try:
        # setup the logger if a filename is provided
        if args.logfile:
            logger = open(args.logfile, "w")

        # batch_size: number of images (and labels) to be considered in a batch
        data_loaders = setup_data_loaders(MNISTCached, args.cuda, args.batch_size, sup_num=args.sup_num)

        # how often would a supervised batch be encountered during inference
        # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised
        # until we have traversed through the all supervised batches
        periodic_interval_batches = int(MNISTCached.train_data_size / (1.0 * args.sup_num))

        # number of unsupervised examples
        unsup_num = MNISTCached.train_data_size - args.sup_num

        # best validation accuracy seen across epochs, and the test accuracy
        # measured in that same epoch (no network state is checkpointed here)
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        # run inference for a certain number of epochs
        for i in range(args.num_epochs):
            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)

            # compute average epoch losses i.e. losses per example
            avg_epoch_losses_sup = [v / args.sup_num for v in epoch_losses_sup]
            avg_epoch_losses_unsup = [v / unsup_num for v in epoch_losses_unsup]

            # store the loss and validation/testing accuracies in the logfile
            str_loss_sup = " ".join(map(str, avg_epoch_losses_sup))
            str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup))

            str_print = "{} epoch: avg losses {}".format(i, "{} {}".format(str_loss_sup, str_loss_unsup))

            validation_accuracy = get_accuracy(data_loaders["valid"], ss_vae.classifier, args.batch_size)
            str_print += " validation accuracy {}".format(validation_accuracy)

            # this test accuracy is only for logging, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size)
            str_print += " test accuracy {}".format(test_accuracy)

            # update the best validation accuracy and remember the testing
            # accuracy observed in that same epoch
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy

            print_and_log(logger, str_print)

        final_test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size)
        print_and_log(logger, "best validation accuracy {} corresponding testing accuracy {} "
                      "last testing accuracy {}".format(best_valid_acc, corresponding_test_acc, final_test_accuracy))

        # visualize the conditional samples
        visualize(ss_vae, viz, data_loaders["test"])
    finally:
        # close the logger file object only if it was actually opened; if
        # open() raised, ``logger`` is still None and the original error
        # propagates instead of being masked by an UnboundLocalError
        if logger is not None:
            logger.close()