tutorial/source/easyguide.ipynb (322 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Writing guides using EasyGuide\n",
"\n",
"This tutorial describes the [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html) module. It assumes the reader is already familiar with [SVI](http://pyro.ai/examples/svi_part_ii.html) and [tensor shapes](http://pyro.ai/examples/tensor_shapes.html).\n",
"\n",
"#### Summary\n",
"\n",
"- For simple black-box guides, try using components in [pyro.infer.autoguide](http://docs.pyro.ai/en/stable/infer.autoguide.html).\n",
"- For more complex guides, try using components in [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html).\n",
"- Decorate with `@easy_guide(model)`.\n",
"- Select multiple model sites using `group = self.group(match=\"my_regex\")`.\n",
"- Guide a group of sites by a single distribution using `group.sample(...)`.\n",
"- Inspect concatenated group shape using `group.batch_shape`, `group.event_shape`, etc.\n",
"- Use `self.plate(...)` instead of `pyro.plate(...)`.\n",
"- To be compatible with subsampling, pass the `event_dim` arg to `pyro.param(...)`.\n",
"- To MAP estimate model site \"foo\", use `foo = self.map_estimate(\"foo\")`.\n",
"\n",
"#### Table of contents\n",
"\n",
"- [Modeling time series data](#Modeling-time-series-data)\n",
"- [Writing a guide without EasyGuide](#Writing-a-guide-without-EasyGuide)\n",
"- [Using EasyGuide](#Using-EasyGuide)\n",
"- [Amortized guides](#Amortized-guides)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"from pyro.infer import SVI, Trace_ELBO\n",
"from pyro.contrib.easyguide import easy_guide\n",
"from pyro.optim import Adam\n",
"from torch.distributions import constraints\n",
"\n",
"pyro.enable_validation(True)\n",
"smoke_test = ('CI' in os.environ)\n",
"assert pyro.__version__.startswith('1.4.0')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Modeling time series data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider a time-series model with a slowly-varying continuous latent state and Bernoulli observations with a logistic link function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model(batch, subsample, full_size):\n",
"    batch = list(batch)\n",
"    num_time_steps = len(batch)\n",
"    drift = pyro.sample(\"drift\", dist.LogNormal(-1, 0.5))\n",
"    with pyro.plate(\"data\", full_size, subsample=subsample):\n",
"        z = 0.\n",
"        for t in range(num_time_steps):\n",
"            z = pyro.sample(\"state_{}\".format(t),\n",
"                            dist.Normal(z, drift))\n",
"            batch[t] = pyro.sample(\"obs_{}\".format(t),\n",
"                                   dist.Bernoulli(logits=z),\n",
"                                   obs=batch[t])\n",
"    return torch.stack(batch)"
]
},
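{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the logistic link concrete: `dist.Bernoulli(logits=z)` gives an observation of 1 with probability `sigmoid(z)`. The snippet below is a minimal standalone sketch in plain Python (no Pyro needed); `sigmoid` is a helper defined here purely for illustration and is not part of the tutorial's model.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def sigmoid(z):\n",
"    # Logistic function: maps a real-valued logit to a probability in (0, 1).\n",
"    return 1.0 / (1.0 + math.exp(-z))\n",
"\n",
"\n",
"# A latent state of 0 corresponds to a fair coin; larger states favor 1.\n",
"for z in (-2.0, 0.0, 2.0):\n",
"    print('z = {:+.1f} -> P(obs = 1) = {:.3f}'.format(z, sigmoid(z)))\n",
"```\n",
"\n",
"Because each state is drawn from a `Normal` centered at the previous state, a small `drift` keeps these probabilities changing slowly over time."
]
},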
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's generate some data directly from the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"full_size = 100\n",
"num_time_steps = 7\n",
"pyro.set_rng_seed(123456789)\n",
"data = model([None] * num_time_steps, torch.arange(full_size), full_size)\n",
"assert data.shape == (num_time_steps, full_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Writing a guide without EasyGuide\n",
"\n",
"Consider a possible guide for this model where we point-estimate the `drift` parameter using a `Delta` distribution, and then model the local time series using shared uncertainty but local means, using a `LowRankMultivariateNormal` distribution. The single global sample site `drift` is handled with one `param` and one `sample` statement. Next we declare a global pair of learnable uncertainty parameters, `cov_diag` and `cov_factor`, shared across all series. Inside the plate we declare a local `loc` parameter using `pyro.param(..., event_dim=...)`, so that it is subsampled along with the data, and draw a joint auxiliary sample site. Finally we unpack that auxiliary site into one `Delta` site per time step. The auxiliary-unpacked-to-`Delta`s pattern is quite common."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rank = 3\n",
"\n",
"def guide(batch, subsample, full_size):\n",
"    num_time_steps, batch_size = batch.shape\n",
"\n",
"    # MAP estimate the drift.\n",
"    drift_loc = pyro.param(\"drift_loc\", lambda: torch.tensor(0.1),\n",
"                           constraint=constraints.positive)\n",
"    pyro.sample(\"drift\", dist.Delta(drift_loc))\n",
"\n",
"    # Model local states using shared uncertainty + local mean.\n",
"    cov_diag = pyro.param(\"state_cov_diag\",\n",
"                          lambda: torch.full((num_time_steps,), 0.01),\n",
"                          constraint=constraints.positive)\n",
"    cov_factor = pyro.param(\"state_cov_factor\",\n",
"                            lambda: torch.randn(num_time_steps, rank) * 0.01)\n",
"    with pyro.plate(\"data\", full_size, subsample=subsample):\n",
"        # Sample local mean.\n",
"        loc = pyro.param(\"state_loc\",\n",
"                         lambda: torch.full((full_size, num_time_steps), 0.5),\n",
"                         event_dim=1)\n",
"        states = pyro.sample(\"states\",\n",
"                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),\n",
"                             infer={\"is_auxiliary\": True})\n",
"        # Unpack the joint states into one sample site per time step.\n",
"        for t in range(num_time_steps):\n",
"            pyro.sample(\"state_{}\".format(t), dist.Delta(states[:, t]))"
]
},
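{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about the shared uncertainty: a `LowRankMultivariateNormal` has covariance `cov_factor @ cov_factor.T + diag(cov_diag)`. The following plain-Python sketch builds that matrix for a tiny example; `low_rank_covariance` is a hypothetical helper written for this illustration, not a Pyro or PyTorch function.\n",
"\n",
"```python\n",
"def low_rank_covariance(cov_factor, cov_diag):\n",
"    # cov_factor: n rows, each of length rank; cov_diag: n positive numbers.\n",
"    n = len(cov_diag)\n",
"    cov = [[sum(a * b for a, b in zip(cov_factor[i], cov_factor[j]))\n",
"            for j in range(n)]\n",
"           for i in range(n)]\n",
"    # Add the diagonal term on top of the low-rank part.\n",
"    for i in range(n):\n",
"        cov[i][i] += cov_diag[i]\n",
"    return cov\n",
"\n",
"\n",
"# Two time steps, rank 1.\n",
"cov = low_rank_covariance([[1.0], [2.0]], [0.5, 0.5])\n",
"print(cov)\n",
"```\n",
"\n",
"The off-diagonal entries come only from `cov_factor`, so that parameter is what lets the guide capture correlation between time steps."
]
},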
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train using [SVI](http://docs.pyro.ai/en/stable/inference_algos.html#module-pyro.infer.svi) and [Trace_ELBO](http://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.trace_elbo.Trace_ELBO), manually batching data into small minibatches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):\n",
"    full_size = data.size(-1)\n",
"    pyro.get_param_store().clear()\n",
"    pyro.set_rng_seed(123456789)\n",
"    svi = SVI(model, guide, Adam({\"lr\": 0.02}), Trace_ELBO())\n",
"    for epoch in range(num_epochs):\n",
"        pos = 0\n",
"        losses = []\n",
"        while pos < full_size:\n",
"            subsample = torch.arange(pos, pos + batch_size)\n",
"            batch = data[:, pos:pos + batch_size]\n",
"            pos += batch_size\n",
"            losses.append(svi.step(batch, subsample, full_size=full_size))\n",
"        epoch_loss = sum(losses) / len(losses)\n",
"        if epoch % 10 == 0:\n",
"            print(\"epoch {} loss = {}\".format(epoch, epoch_loss / data.numel()))"
]
},
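{
"cell_type": "markdown",
"metadata": {},
"source": [
"The manual batching in `train` walks through contiguous index ranges. Here is a small sketch of that index arithmetic in plain Python. `minibatch_bounds` is a hypothetical helper introduced only for this illustration; unlike the loop in `train`, which relies on `full_size` being a multiple of `batch_size`, it clips the final range.\n",
"\n",
"```python\n",
"def minibatch_bounds(full_size, batch_size):\n",
"    # Yield (start, stop) pairs that cover range(full_size) in order.\n",
"    for start in range(0, full_size, batch_size):\n",
"        yield start, min(start + batch_size, full_size)\n",
"\n",
"\n",
"# Same sizes as the tutorial: 100 series in minibatches of 20.\n",
"bounds = list(minibatch_bounds(100, 20))\n",
"print(bounds)\n",
"```\n",
"\n",
"Each `(start, stop)` pair plays the role of `pos` and `pos + batch_size` in `train`, giving five `svi.step` calls per epoch for these sizes."
]
},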
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train(guide)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using EasyGuide\n",
"\n",
"Now let's simplify using the `@easy_guide` decorator. Our modifications are:\n",
"1. Decorate with `@easy_guide` and add `self` to args.\n",
"2. Replace the `Delta` guide for drift with a simple `map_estimate()`.\n",
"3. Select a `group` of model sites and read their concatenated `event_shape`.\n",
"4. Replace the auxiliary site and `Delta` slices with a single `group.sample()`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@easy_guide(model)\n",
"def guide(self, batch, subsample, full_size):\n",
"    # MAP estimate the drift.\n",
"    self.map_estimate(\"drift\")\n",
"\n",
"    # Model local states using shared uncertainty + local mean.\n",
"    group = self.group(match=\"state_[0-9]*\")  # Selects all local variables.\n",
"    cov_diag = pyro.param(\"state_cov_diag\",\n",
"                          lambda: torch.full(group.event_shape, 0.01),\n",
"                          constraint=constraints.positive)\n",
"    cov_factor = pyro.param(\"state_cov_factor\",\n",
"                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n",
"    with self.plate(\"data\", full_size, subsample=subsample):\n",
"        # Sample local mean.\n",
"        loc = pyro.param(\"state_loc\",\n",
"                         lambda: torch.full((full_size,) + group.event_shape, 0.5),\n",
"                         event_dim=1)\n",
"        # Automatically sample the joint latent, then unpack and replay model sites.\n",
"        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that we've used `group.event_shape` to read the flattened, concatenated event shape of all matched sites in the group, instead of computing it by hand from `batch.shape`."
]
},
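{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which site names a pattern such as `state_[0-9]*` would select, the regex can be tried out on its own. The sketch below assumes the pattern is applied to each site name with `re.match`-style matching anchored at the start of the name; that is an assumption about EasyGuide's internals, so check the module documentation if the exact semantics matter.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"pattern = 'state_[0-9]*'\n",
"# Latent sample sites of the tutorial model with 7 time steps.\n",
"site_names = ['drift'] + ['state_{}'.format(t) for t in range(7)]\n",
"\n",
"# Keep the names that the pattern matches from the start of the string.\n",
"matched = [name for name in site_names if re.match(pattern, name)]\n",
"print(matched)\n",
"```\n",
"\n",
"Under the same assumption, `*` also allows zero digits, so a hypothetical site named `state_scale` would match as well; a stricter pattern such as `state_[0-9]+$` avoids that."
]
},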
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train(guide)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Amortized guides\n",
"\n",
"`EasyGuide` also makes it easy to write amortized guides (guides where we learn a function that predicts latent variables from data, rather than learning one parameter per datapoint). Let's modify the last guide to predict the latent `loc` as an affine function of observed data, rather than memorizing each data point's latent variable. This amortized guide is more useful in practice because it can handle new data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@easy_guide(model)\n",
"def guide(self, batch, subsample, full_size):\n",
"    num_time_steps, batch_size = batch.shape\n",
"    self.map_estimate(\"drift\")\n",
"\n",
"    group = self.group(match=\"state_[0-9]*\")\n",
"    cov_diag = pyro.param(\"state_cov_diag\",\n",
"                          lambda: torch.full(group.event_shape, 0.01),\n",
"                          constraint=constraints.positive)\n",
"    cov_factor = pyro.param(\"state_cov_factor\",\n",
"                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n",
"\n",
"    # Predict latent propensity as an affine function of observed data.\n",
"    if not hasattr(self, \"nn\"):\n",
"        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())\n",
"        self.nn.weight.data.fill_(1.0 / num_time_steps)\n",
"        self.nn.bias.data.fill_(-0.5)\n",
"    pyro.module(\"state_nn\", self.nn)\n",
"    with self.plate(\"data\", full_size, subsample=subsample):\n",
"        loc = self.nn(batch.t())\n",
"        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))"
]
},
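{
"cell_type": "markdown",
"metadata": {},
"source": [
"It can help to see what the linear layer predicts before any training. With every weight set to `1 / num_time_steps` and every bias set to `-0.5`, each predicted mean is the average of a series' observations minus 0.5. The plain-Python sketch below reproduces that arithmetic; `initial_loc` is a hypothetical helper used only for illustration.\n",
"\n",
"```python\n",
"def initial_loc(observations):\n",
"    # Every output unit starts with weights 1/T and bias -0.5, so each\n",
"    # predicted mean equals the average observation minus 0.5.\n",
"    num_time_steps = len(observations)\n",
"    value = sum(x / num_time_steps for x in observations) - 0.5\n",
"    return [value] * num_time_steps\n",
"\n",
"\n",
"print(initial_loc([1, 1, 1, 1, 1, 1, 1]))  # a series of all ones\n",
"print(initial_loc([0, 0, 0, 0, 0, 0, 0]))  # a series of all zeros\n",
"```\n",
"\n",
"Series that are mostly ones therefore start with positive predicted states and series that are mostly zeros with negative ones, which is a sensible starting point under the logistic link."
]
},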
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train(guide)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}