in examples/air/main.py [0:0]
def main(**kwargs):
    args = argparse.Namespace(**kwargs)
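    # If a save path was given, refuse to overwrite an existing output file.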
    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)
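    # Load the dataset and the ground-truth object counts used for evaluation.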
    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        def base_z_pres_prior_p(t):
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)
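    # Optional model arguments are forwarded to AIR only when present on the command line.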
    model_arg_keys = ['window_size',
                      'rnn_hidden_size',
                      'decoder_output_bias',
                      'decoder_output_use_sigmoid',
                      'baseline_scalar',
                      'encoder_net',
                      'decoder_net',
                      'predict_net',
                      'embed_net',
                      'bl_predict_net',
                      'non_linearity',
                      'pos_prior_mean',
                      'pos_prior_sd',
                      'scale_prior_mean',
                      'scale_prior_sd']
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
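    # Instantiate the AIR model/guide pair.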
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)
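    # Optionally warm-start from a previously saved state dict.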
    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    # Viz sample from prior.
    if args.viz:
        vis = visdom.Visdom(env=args.visdom_env)
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))
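    # Baseline (variance-reduction) parameters are trained with their own learning rate.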
    def isBaselineParam(module_name, param_name):
        return 'bl_' in module_name or 'bl_' in param_name

    def per_param_optim_args(module_name, param_name):
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name) else args.learning_rate
        return {'lr': lr}
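    # pyro.optim.Adam accepts a callable returning per-parameter optimizer arguments,
    # here keyed by module and parameter name.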
    adam = optim.Adam(per_param_optim_args)
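    # TraceGraph_ELBO reduces the variance of gradient estimates for the discrete
    # z_pres choices; the JIT-compiled variant is used when requested.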
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]
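    # Main training loop: each iteration takes one SVI gradient step on a mini-batch.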
    for i in range(1, args.num_steps + 1):

        loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))
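        # Periodically visualise inferred object positions and reconstructions in visdom.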
        if args.viz and i % args.viz_every == 0:
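            # Run the guide on a few fixed examples, then replay the generative model
            # against the guide trace to obtain reconstructions.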
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))

            # Show reconstructions of the data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on a subset of the training data.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))
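        # Periodically checkpoint model parameters if a save path was given.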
        if 'save' in args and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)