in examples/neutra.py [0:0]
def _parse_lim(spec):
    """Parse a ``"lo,hi"`` string (e.g. ``"-3,3"``) into a ``(lo, hi)`` tuple of floats.

    :raises ValueError: if ``spec`` does not contain exactly two comma-separated numbers.
    """
    lim = tuple(float(v) for v in spec.strip().split(','))
    if len(lim) != 2:
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        raise ValueError('expected limits of the form "lo,hi", got {!r}'.format(spec))
    return lim


def _plot_density(ax, grid, xlim, ylim, title, samples=None):
    """Draw the target density as filled contours on ``ax``, optionally overlaying a sample KDE.

    :param grid: triple ``(x1, x2, p)`` of meshgrid coordinates and density values.
    :param samples: optional array whose first two columns are plotted as a 2D KDE.
    """
    x1, x2, p = grid
    ax.contourf(x1, x2, p, cmap='OrRd')
    ax.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, title=title)
    if samples is not None:
        # x/y must be keywords: seaborn >= 0.12 rejects positionally passed data vectors.
        sns.kdeplot(x=samples[:, 0], y=samples[:, 1], ax=ax)


def _plot_warped(ax, zs, title):
    """Scatter-plot the first two coordinates of the NeuTra latent samples ``zs`` (a tensor)."""
    # Copy to host memory first: plotting libraries cannot consume CUDA tensors.
    zs = zs.cpu().numpy()
    sns.scatterplot(x=zs[:, 0], y=zs[:, 1], alpha=0.2, ax=ax)
    ax.set(xlabel='x0', ylabel='x1', title=title)


def _fit_and_run_neutra(args, guide, name, axes, grid, xlim, ylim):
    """Fit ``guide``, plot its samples, then run NeuTra HMC with it and plot the results.

    :param guide: an (unfitted) autoguide for ``model``.
    :param name: short label (e.g. ``'BNAF'``) used in log messages and plot titles.
    :param axes: triple ``(ax_guide, ax_warped, ax_transformed)`` to draw into.
    :param grid: triple ``(x1, x2, p)`` describing the target density contours.
    """
    ax_guide, ax_warped, ax_transformed = axes

    # (a) Fit the autoguide and draw samples directly from it.
    logging.info('\nFitting a {} autoguide ...'.format(name))
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()
    _plot_density(ax_guide, grid, xlim, ylim,
                  'Posterior \n({} autoguide)'.format(name), guide_samples)

    # (b) Use the fitted guide to reparametrize the model and run HMC in the warped space.
    logging.info('\nDrawing samples using {} autoguide + NeuTra HMC ...'.format(name))
    # Turn off requires_grad on the fitted guide's parameters before handing it to NeuTra.
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    _plot_warped(ax_warped, zs, 'Posterior (warped) samples \n({} + NeuTra HMC)'.format(name))
    # Map the warped latent samples back to the original space of site 'x'.
    samples = neutra.transform_sample(zs)['x'].cpu().numpy()
    _plot_density(ax_transformed, grid, xlim, ylim,
                  'Posterior (transformed) \n({} + NeuTra HMC)'.format(name), samples)


def main(args):
    """Compare vanilla HMC with NeuTra HMC on the BananaShaped distribution.

    Draws a 4x2 grid of plots (target density, vanilla HMC posterior, and for both a
    DiagNormal and a BNAF autoguide: guide samples, warped NeuTra samples and
    transformed NeuTra samples) and saves it as ``neutra.pdf`` next to this script.

    Reads ``args.rng_seed``, ``args.x_lim``, ``args.y_lim`` (``"lo,hi"`` strings),
    ``args.param_a``, ``args.param_b``, ``args.num_samples`` and ``args.num_flows``;
    ``args`` is also forwarded to ``run_hmc`` and ``fit_guide``.

    :raises ValueError: if ``args.x_lim`` or ``args.y_lim`` is not of the form ``"lo,hi"``.
    """
    pyro.set_rng_seed(args.rng_seed)
    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    # Row-major order: axes[0] = gs[0, 0], axes[1] = gs[0, 1], axes[2] = gs[1, 0], ...
    axes = [fig.add_subplot(gs[row, col]) for row in range(4) for col in range(2)]
    xlim = _parse_lim(args.x_lim)
    ylim = _parse_lim(args.y_lim)

    # 1. Plot the density of the BananaShaped distribution on a 100x100 grid.
    # NOTE: relies on torch.meshgrid's default 'ij' indexing (x1 varies along dim 0).
    x1, x2 = torch.meshgrid([torch.linspace(*xlim, 100), torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))
    grid = (x1, x2, p)
    _plot_density(axes[0], grid, xlim, ylim, 'BananaShaped distribution: \ndensity')

    # 2. Run vanilla HMC.
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    _plot_density(axes[1], grid, xlim, ylim, 'Posterior \n(vanilla HMC)', vanilla_samples)

    # 3. Diagonal normal autoguide, then NeuTra HMC with it.
    _fit_and_run_neutra(args, AutoDiagonalNormal(model, init_scale=0.05), 'DiagNormal',
                        axes[2:5], grid, xlim, ylim)

    # 4. Block neural autoregressive flow (BNAF) autoguide, then NeuTra HMC with it.
    bnaf_guide = AutoNormalizingFlow(model, partial(iterated, args.num_flows, block_autoregressive))
    _fit_and_run_neutra(args, bnaf_guide, 'BNAF', axes[5:8], grid, xlim, ylim)

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))