examples/air/main.py (246 lines of code) (raw):

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

"""
AIR applied to the multi-mnist data set [1].

[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
understanding with generative models." Advances in Neural Information
Processing Systems. 2016.
"""

import argparse
import math
import os
import time
from functools import partial

import numpy as np
import torch
import visdom

import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from air import AIR, latents_to_tensor
from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
from viz import draw_many, tensor_to_objs


def count_accuracy(X, true_counts, air, batch_size):
    """Measure object-count accuracy of the guide on data set ``X``.

    Runs ``air.guide`` over ``X`` in mini-batches and compares the number
    of inferred objects (sum of ``z_pres`` indicators) with the known
    ``true_counts``.

    :param X: input images, shape ``(N, ...)`` with ``N`` a multiple of
        ``batch_size``.
    :param true_counts: tensor of ``N`` ground-truth object counts
        (values in ``{0, 1, 2}``; see ``count_vec_to_mat(..., 2)``).
    :param air: a fitted :class:`AIR` instance providing ``guide``.
    :param batch_size: mini-batch size used for evaluation.
    :returns: ``(acc, counts, error_latents, error_indices)`` where
        ``acc`` is the fraction of correctly counted images, ``counts``
        is a 3x4 confusion matrix (true count x inferred count), and the
        last two describe the mis-counted examples.
    """
    assert X.size(0) == true_counts.size(0), 'Size mismatch.'
    assert X.size(0) % batch_size == 0, 'Input size must be multiple of batch_size.'
    counts = torch.LongTensor(3, 4).zero_()
    error_latents = []
    error_indicators = []

    def count_vec_to_mat(vec, max_index):
        # One-hot encode a vector of counts into an (N, max_index + 1) matrix.
        out = torch.LongTensor(vec.size(0), max_index + 1).zero_()
        out.scatter_(1, vec.type(torch.LongTensor).view(vec.size(0), 1), 1)
        return out

    for i in range(X.size(0) // batch_size):
        X_batch = X[i * batch_size:(i + 1) * batch_size]
        true_counts_batch = true_counts[i * batch_size:(i + 1) * batch_size]
        z_where, z_pres = air.guide(X_batch, batch_size)
        # Number of objects inferred = number of steps with z_pres == 1.
        inferred_counts = sum(z.cpu() for z in z_pres).squeeze().data
        true_counts_m = count_vec_to_mat(true_counts_batch, 2)
        inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
        counts += torch.mm(true_counts_m.t(), inferred_counts_m)
        # Fix: `1 - bool_tensor` raises in modern PyTorch; `!=` is equivalent.
        error_ind = true_counts_batch != inferred_counts
        # Fix: squeeze(-1) (not squeeze()) so that a batch with exactly one
        # error still yields a 1-d index tensor usable by index_select.
        error_ix = error_ind.nonzero().squeeze(-1)
        error_latents.append(latents_to_tensor((z_where, z_pres)).index_select(0, error_ix))
        error_indicators.append(error_ind)

    acc = counts.diag().sum().float() / X.size(0)
    error_indices = torch.cat(error_indicators).nonzero().squeeze(-1)
    if X.is_cuda:
        error_indices = error_indices.cuda()
    return acc, counts, torch.cat(error_latents), error_indices


# Defines something like a truncated geometric. Like the geometric,
# this has the property that there's a constant difference in log prob
# between p(steps=n) and p(steps=n+1).
def make_prior(k):
    """Return ``t -> p`` giving the z_pres success probability at step ``t``.

    ``k`` in ``(0, 1]`` is the log-prob ratio between successive step
    counts of the induced (truncated-geometric-like) distribution.
    """
    assert 0 < k <= 1
    u = 1 / (1 + k + k**2 + k**3)
    p0 = 1 - u
    p1 = 1 - (k * u) / p0
    p2 = 1 - (k**2 * u) / (p0 * p1)
    trial_probs = [p0, p1, p2]
    # dist = [1 - p0, p0 * (1 - p1), p0 * p1 * (1 - p2), p0 * p1 * p2]
    # print(dist)
    return lambda t: trial_probs[t]


# Implements "prior annealing" as described in this blog post:
# http://akosiorek.github.io/ml/2017/09/03/implementing-air.html

# That implementation does something very close to the following:
# --z-pres-prior (1 - 1e-15)
# --z-pres-prior-raw
# --anneal-prior exp
# --anneal-prior-to 1e-7
# --anneal-prior-begin 1000
# --anneal-prior-duration 1e6

# e.g. After 200K steps z_pres_p will have decayed to ~0.04

# These compute the value of a decaying value at time t.
# initial: initial value
# final: final value, reached after begin + duration steps
# begin: number of steps before decay begins
# duration: number of steps over which decay occurs
# t: current time step
def lin_decay(initial, final, begin, duration, t):
    """Linearly decay from ``initial`` to ``final``; clamped outside the window."""
    assert duration > 0
    x = (final - initial) * (t - begin) / duration + initial
    return max(min(x, initial), final)


def exp_decay(initial, final, begin, duration, t):
    """Exponentially decay from ``initial`` to ``final``; clamped outside the window."""
    assert final > 0
    assert duration > 0
    # half_life = math.log(2) / math.log(initial / final) * duration
    decay_rate = math.log(initial / final) / duration
    x = initial * math.exp(-decay_rate * (t - begin))
    return max(min(x, initial), final)


def load_data():
    """Load the multi-mnist data set.

    :returns: ``(X, counts)`` where ``X`` is a float tensor of images
        scaled to ``[0, 1]`` and ``counts`` holds the number of objects
        present in each image.
    """
    inpath = get_data_directory(__file__)
    X_np, Y = multi_mnist.load(inpath)
    X_np = X_np.astype(np.float32)
    X_np /= 255.0
    X = torch.from_numpy(X_np)
    # Using FloatTensor to allow comparison with values sampled from
    # Bernoulli.
    counts = torch.FloatTensor([len(objs) for objs in Y])
    return X, counts


def main(**kwargs):
    """Train AIR on multi-mnist; ``kwargs`` mirror the command-line flags."""
    args = argparse.Namespace(**kwargs)

    # Fail early rather than clobbering an existing checkpoint at save time.
    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        def base_z_pres_prior_p(t):
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    # Only forward the model kwargs the user actually supplied
    # (argument_default=SUPPRESS leaves unset flags off the Namespace).
    model_arg_keys = ['window_size',
                      'rnn_hidden_size',
                      'decoder_output_bias',
                      'decoder_output_use_sigmoid',
                      'baseline_scalar',
                      'encoder_net',
                      'decoder_net',
                      'predict_net',
                      'embed_net',
                      'bl_predict_net',
                      'non_linearity',
                      'pos_prior_mean',
                      'pos_prior_sd',
                      'scale_prior_mean',
                      'scale_prior_sd']
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)

    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    # Viz sample from prior.
    if args.viz:
        vis = visdom.Visdom(env=args.visdom_env)
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def isBaselineParam(module_name, param_name):
        # Baseline nets are identified by a 'bl_' prefix in their names.
        return 'bl_' in module_name or 'bl_' in param_name

    def per_param_optim_args(module_name, param_name):
        # Baseline parameters get their own (typically larger) learning rate.
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name) else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        if args.viz and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred objection positions.
            vis.images(draw_many(examples_to_viz, z_wheres))

            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        if 'save' in args and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)


if __name__ == '__main__':
    assert pyro.__version__.startswith('1.4.0')
    parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS)
    parser.add_argument('-n', '--num-steps', type=int, default=int(1e8),
                        help='number of optimization steps to take')
    parser.add_argument('-b', '--batch-size', type=int, default=64,
                        help='batch size')
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('-blr', '--baseline-learning-rate', type=float, default=1e-3,
                        help='baseline learning rate')
    parser.add_argument('--progress-every', type=int, default=1,
                        help='number of steps between writing progress to stdout')
    parser.add_argument('--eval-every', type=int, default=0,
                        help='number of steps between evaluations')
    parser.add_argument('--baseline-scalar', type=float,
                        help='scale the output of the baseline nets by this value')
    parser.add_argument('--no-baselines', action='store_true', default=False,
                        help='do not use data dependent baselines')
    parser.add_argument('--encoder-net', type=int, nargs='+', default=[200],
                        help='encoder net hidden layer sizes')
    parser.add_argument('--decoder-net', type=int, nargs='+', default=[200],
                        help='decoder net hidden layer sizes')
    parser.add_argument('--predict-net', type=int, nargs='+',
                        help='predict net hidden layer sizes')
    parser.add_argument('--embed-net', type=int, nargs='+',
                        help='embed net architecture')
    parser.add_argument('--bl-predict-net', type=int, nargs='+',
                        help='baseline predict net hidden layer sizes')
    parser.add_argument('--non-linearity', type=str,
                        help='non linearity to use throughout')
    parser.add_argument('--viz', action='store_true', default=False,
                        help='generate vizualizations during optimization')
    parser.add_argument('--viz-every', type=int, default=100,
                        help='number of steps between vizualizations')
    parser.add_argument('--visdom-env', default='main',
                        help='visdom enviroment name')
    parser.add_argument('--load', type=str,
                        help='load previously saved parameters')
    parser.add_argument('--save', type=str,
                        help='save parameters to specified file')
    # Fix: `type=int` is not applied to defaults, so the default must be an
    # int itself (1e4 is a float).
    parser.add_argument('--save-every', type=int, default=int(1e4),
                        help='number of steps between parameter saves')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    parser.add_argument('--jit', action='store_true', default=False,
                        help='use PyTorch jit')
    parser.add_argument('-t', '--model-steps', type=int, default=3,
                        help='number of time steps')
    parser.add_argument('--rnn-hidden-size', type=int, default=256,
                        help='rnn hidden size')
    parser.add_argument('--encoder-latent-size', type=int, default=50,
                        help='attention window encoder/decoder latent space size')
    parser.add_argument('--decoder-output-bias', type=float,
                        help='bias added to decoder output (prior to applying non-linearity)')
    parser.add_argument('--decoder-output-use-sigmoid', action='store_true',
                        help='apply sigmoid function to output of decoder network')
    parser.add_argument('--window-size', type=int, default=28,
                        help='attention window size')
    parser.add_argument('--z-pres-prior', type=float, default=0.5,
                        help='prior success probability for z_pres')
    parser.add_argument('--z-pres-prior-raw', action='store_true', default=False,
                        help='use --z-pres-prior directly as success prob instead of a geometric like prior')
    parser.add_argument('--anneal-prior', choices='none lin exp'.split(), default='none',
                        help='anneal z_pres prior during optimization')
    parser.add_argument('--anneal-prior-to', type=float, default=1e-7,
                        help='target z_pres prior prob')
    parser.add_argument('--anneal-prior-begin', type=int, default=0,
                        help='number of steps to wait before beginning to anneal the prior')
    parser.add_argument('--anneal-prior-duration', type=int, default=100000,
                        help='number of steps over which to anneal the prior')
    parser.add_argument('--pos-prior-mean', type=float,
                        help='mean of the window position prior')
    parser.add_argument('--pos-prior-sd', type=float,
                        help='std. dev. of the window position prior')
    parser.add_argument('--scale-prior-mean', type=float,
                        help='mean of the window scale prior')
    parser.add_argument('--scale-prior-sd', type=float,
                        help='std. dev. of the window scale prior')
    parser.add_argument('--no-masking', action='store_true', default=False,
                        help='do not mask out the costs of unused choices')
    parser.add_argument('--seed', type=int, help='random seed', default=None)
    parser.add_argument('-v', '--verbose', action='store_true', default=False,
                        help='write hyper parameters and network architecture to stdout')

    main(**vars(parser.parse_args()))