def main()

in examples/neutra.py [0:0]


def _parse_limits(spec, name):
    """Parse a ``'low,high'`` string into a 2-tuple of floats.

    :param str spec: comma-separated pair of numbers, e.g. ``'-3,3'``.
    :param str name: option name, used only in the error message.
    :returns: ``(low, high)`` as floats.
    :raises ValueError: if ``spec`` does not contain exactly two numbers.
    """
    try:
        limits = tuple(float(v) for v in spec.strip().split(','))
    except ValueError as e:
        raise ValueError(
            '{} must be two comma-separated numbers, got {!r}'.format(name, spec)
        ) from e
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(limits) != 2:
        raise ValueError(
            '{} must have exactly two values (low,high), got {!r}'.format(name, spec)
        )
    return limits


def _plot_density(ax, title, samples=None, *, x1, x2, p, xlim, ylim):
    """Draw the filled density contours on ``ax`` and optionally overlay a KDE.

    :param ax: matplotlib axes to draw on.
    :param str title: axes title.
    :param samples: optional array indexed as ``samples[:, 0]`` / ``samples[:, 1]``;
        when given, a 2-D kernel density estimate of it is drawn over the contours.
    :param x1: grid coordinates for the first axis (first ``contourf`` argument).
    :param x2: grid coordinates for the second axis (second ``contourf`` argument).
    :param p: density values evaluated on the grid.
    :param xlim: ``(low, high)`` limits applied to the x axis.
    :param ylim: ``(low, high)`` limits applied to the y axis.
    """
    ax.contourf(x1, x2, p, cmap='OrRd')
    ax.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, title=title)
    if samples is not None:
        # Keyword ``x=``/``y=`` is required by seaborn >= 0.12; positional
        # vectors were deprecated in 0.11 and later removed.
        sns.kdeplot(x=samples[:, 0], y=samples[:, 1], ax=ax)


def _plot_warped(ax, zs, title):
    """Scatter-plot the first two columns of the warped-space tensor ``zs``."""
    # Convert to NumPy like every other plot in this script does, so that
    # seaborn never receives a torch tensor.
    zs = zs.detach().cpu().numpy()
    sns.scatterplot(x=zs[:, 0], y=zs[:, 1], alpha=0.2, ax=ax)
    ax.set(xlabel='x0', ylabel='x1', title=title)


def _sample_guide(guide, num_samples):
    """Call ``guide`` inside a plate of size ``num_samples`` and return its ``'x'`` value as NumPy."""
    with pyro.plate('N', num_samples):
        return guide()['x'].detach().cpu().numpy()


def _run_neutra_hmc(args, guide):
    """Run HMC on ``model`` reparameterized through ``guide`` via NeuTra.

    :returns: ``(zs, samples)`` where ``zs`` is the ``'x_shared_latent'`` entry
        of the MCMC samples and ``samples`` is the ``'x'`` entry of
        ``neutra.transform_sample(zs)`` converted to a NumPy array.
    """
    # ``requires_grad_(False)`` is applied to the fitted guide before it is
    # wrapped, exactly as in the original inline code.
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    samples = neutra.transform_sample(zs)['x'].cpu().numpy()
    return zs, samples


def main(args):
    """Compare vanilla HMC, autoguides and NeuTra HMC on ``BananaShaped`` and save the figure.

    Builds a 4x2 grid of plots (density, vanilla HMC, and for each of the
    DiagNormal and BNAF autoguides: guide samples, warped NeuTra samples and
    transformed NeuTra samples) and writes it to ``neutra.pdf`` in the
    directory of this file.

    :param args: namespace providing ``rng_seed``, ``x_lim`` and ``y_lim``
        (``'low,high'`` strings), ``param_a``, ``param_b``, ``num_samples``
        and ``num_flows``; it is also forwarded to ``run_hmc`` and
        ``fit_guide``.
    :raises ValueError: if ``x_lim`` or ``y_lim`` is not a pair of numbers.
    """
    pyro.set_rng_seed(args.rng_seed)
    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    # Row-major order: ax1/ax2 on the first row, ax7/ax8 on the last.
    ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8 = (
        fig.add_subplot(gs[row, col]) for row in range(4) for col in range(2)
    )
    xlim = _parse_limits(args.x_lim, 'x_lim')
    ylim = _parse_limits(args.y_lim, 'y_lim')

    # 1. Plot the density of the BananaShaped distribution on a 100x100 grid
    x1, x2 = torch.meshgrid([torch.linspace(*xlim, 100), torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))
    # Every density panel shares the same grid, density and axis limits.
    plot_density = partial(_plot_density, x1=x1, x2=x2, p=p, xlim=xlim, ylim=ylim)
    plot_density(ax1, 'BananaShaped distribution: \nlog density')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    plot_density(ax2, 'Posterior \n(vanilla HMC)', vanilla_samples)

    # 3(a). Fit a diagonal normal autoguide
    logging.info('\nFitting a DiagNormal autoguide ...')
    guide = AutoDiagonalNormal(model, init_scale=0.05)
    fit_guide(guide, args)
    guide_samples = _sample_guide(guide, args.num_samples)
    plot_density(ax3, 'Posterior \n(DiagNormal autoguide)', guide_samples)

    # 3(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...')
    zs, samples = _run_neutra_hmc(args, guide)
    _plot_warped(ax4, zs, 'Posterior (warped) samples \n(DiagNormal + NeuTra HMC)')
    plot_density(ax5, 'Posterior (transformed) \n(DiagNormal + NeuTra HMC)', samples)

    # 4(a). Fit a BNAF autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(model, partial(iterated, args.num_flows, block_autoregressive))
    fit_guide(guide, args)
    guide_samples = _sample_guide(guide, args.num_samples)
    plot_density(ax6, 'Posterior \n(BNAF autoguide)', guide_samples)

    # 4(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    zs, samples = _run_neutra_hmc(args, guide)
    _plot_warped(ax7, zs, 'Posterior (warped) samples \n(BNAF + NeuTra HMC)')
    plot_density(ax8, 'Posterior (transformed) \n(BNAF + NeuTra HMC)', samples)

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))