{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Kalman Filter\n",
"\n",
"Kalman filters are linear models for state estimation of dynamic systems [1]. They have been the <i>de facto</i> standard in many robotics and tracking/prediction applications because they are well suited for systems with uncertainty about an observable dynamic process. They use a \"observe, predict, correct\" paradigm to extract information from an otherwise noisy signal. In Pyro, we can build differentiable Kalman filters with learnable parameters using the `pyro.contrib.tracking` [library](http://docs.pyro.ai/en/dev/contrib.tracking.html#module-pyro.contrib.tracking.extended_kalman_filter)\n",
"\n",
"## Dynamic process\n",
"\n",
"To start, consider this simple motion model:\n",
"\n",
"$$ X_{k+1} = FX_k + \\mathbf{W}_k $$\n",
"$$ \\mathbf{Z}_k = HX_k + \\mathbf{V}_k $$\n",
"\n",
"where $k$ is the state, $X$ is the signal estimate, $Z_k$ is the observed value at timestep $k$, $\\mathbf{W}_k$ and $\\mathbf{V}_k$ are independent noise processes (ie $\\mathbb{E}[w_k v_j^T] = 0$ for all $j, k$) which we'll approximate as Gaussians. Note that the state transitions are linear."
]
},
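{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this concrete, the short sketch below simulates the linear model above for a toy 2-D constant-velocity state $X = [x, y, v_x, v_y]$ in which only the positions are observed. The transition matrix $F$, the measurement matrix $H$, and the noise scales are made-up values chosen purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"dt = 0.1\n",
"# constant-velocity transition: positions advance by velocity * dt\n",
"F = torch.tensor([[1., 0., dt, 0.],\n",
"                  [0., 1., 0., dt],\n",
"                  [0., 0., 1., 0.],\n",
"                  [0., 0., 0., 1.]])\n",
"# we only observe the positions\n",
"H = torch.tensor([[1., 0., 0., 0.],\n",
"                  [0., 1., 0., 0.]])\n",
"\n",
"X = torch.tensor([0., 0., 1., 0.5])  # initial state\n",
"for k in range(5):\n",
"    W = 0.01 * torch.randn(4)  # process noise W_k\n",
"    V = 0.05 * torch.randn(2)  # measurement noise V_k\n",
"    X = F @ X + W              # X_{k+1} = F X_k + W_k\n",
"    Z = H @ X + V              # Z_k = H X_k + V_k\n",
"    print(Z)"
]
},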
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kalman Update\n",
"At each time step, we perform a prediction for the mean and covariance:\n",
"$$ \\hat{X}_k = F\\hat{X}_{k-1}$$\n",
"$$\\hat{P}_k = FP_{k-1}F^T + Q$$\n",
"\n",
"and a correction for the measurement:\n",
"\n",
"$$ K_k = \\hat{P}_k H^T(H\\hat{P}_k H^T + R)^{-1}$$\n",
"$$ X_k = \\hat{X}_k + K_k(z_k - H\\hat{X}_k)$$\n",
"$$ P_k = (I-K_k H)\\hat{P}_k$$\n",
"\n",
"where $X$ is the position estimate, $P$ is the covariance matrix, $K$ is the Kalman Gain, and $Q$ and $R$ are covariance matrices.\n",
"\n",
"For an in-depth derivation, see \\[2\\]"
]
},
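{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below runs a single predict/correct cycle of these equations on toy matrices (the same constant-velocity $F$ and position-only $H$ as in the sketch above, redefined so the cell runs on its own). All of the numbers are placeholders for illustration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"dt = 0.1\n",
"F = torch.tensor([[1., 0., dt, 0.],\n",
"                  [0., 1., 0., dt],\n",
"                  [0., 0., 1., 0.],\n",
"                  [0., 0., 0., 1.]])\n",
"H = torch.tensor([[1., 0., 0., 0.],\n",
"                  [0., 1., 0., 0.]])\n",
"Q = 1e-3 * torch.eye(4)  # process noise covariance\n",
"R = 1e-2 * torch.eye(2)  # measurement noise covariance\n",
"\n",
"X = torch.zeros(4)              # previous state estimate X_{k-1}\n",
"P = torch.eye(4)                # previous covariance P_{k-1}\n",
"z = torch.tensor([0.11, 0.05])  # new measurement z_k\n",
"\n",
"# predict\n",
"X_hat = F @ X\n",
"P_hat = F @ P @ F.T + Q\n",
"\n",
"# correct\n",
"S = H @ P_hat @ H.T + R             # innovation covariance\n",
"K = P_hat @ H.T @ torch.inverse(S)  # Kalman gain\n",
"X = X_hat + K @ (z - H @ X_hat)     # corrected state estimate\n",
"P = (torch.eye(4) - K @ H) @ P_hat  # corrected covariance\n",
"print(X)"
]
},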
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nonlinear Estimation: Extended Kalman Filter\n",
"\n",
"What if our system is non-linear, eg in GPS navigation? Consider the following non-linear system:\n",
"\n",
"$$ X_{k+1} = \\mathbf{f}(X_k) + \\mathbf{W}_k $$\n",
"$$ \\mathbf{Z}_k = \\mathbf{h}(X_k) + \\mathbf{V}_k $$\n",
"\n",
"Notice that $\\mathbf{f}$ and $\\mathbf{h}$ are now (smooth) non-linear functions.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Extended Kalman Filter (EKF) attacks this problem by using a local linearization of the Kalman filter via a [Taylors Series expansion](https://en.wikipedia.org/wiki/Taylor_series).\n",
"\n",
"$$ f(X_k, k) \\approx f(x_k^R, k) + \\mathbf{H}_k(X_k - x_k^R) + \\cdots$$\n",
"\n",
"where $\\mathbf{H}_k$ is the Jacobian matrix at time $k$, $x_k^R$ is the previous optimal estimate, and we ignore the higher order terms. At each time step, we compute a Jacobian conditioned the previous predictions (this computation is handled by Pyro under the hood), and use the result to perform a prediction and update.\n",
"\n",
"Omitting the derivations, the modification to the above predictions are now:\n",
"$$ \\hat{X}_k \\approx \\mathbf{f}(X_{k-1}^R)$$\n",
"$$ \\hat{P}_k = \\mathbf{H}_\\mathbf{f}(X_{k-1})P_{k-1}\\mathbf{H}_\\mathbf{f}^T(X_{k-1}) + Q$$\n",
"\n",
"and the updates are now:\n",
"\n",
"$$ X_k \\approx \\hat{X}_k + K_k\\big(z_k - \\mathbf{h}(\\hat{X}_k)\\big)$$\n",
"$$ K_k = \\hat{P}_k \\mathbf{H}_\\mathbf{h}(\\hat{X}_k) \\Big(\\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\hat{P}_k \\mathbf{H}_\\mathbf{h}(\\hat{X}_k) + R_k\\Big)^{-1} $$\n",
"$$ P_k = \\big(I - K_k \\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\big)\\hat{P}_K$$\n",
"\n",
"In Pyro, all we need to do is create an `EKFState` object and use its `predict` and `update` methods. Pyro will do exact inference to compute the innovations and we will use SVI to learn a MAP estimate of the position and measurement covariances.\n",
"\n",
"As an example, let's look at an object moving at near-constant velocity in 2-D in a discrete time space over 100 time steps."
]
},
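{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before setting up the Pyro example, the cell below is a minimal hand-rolled sketch of a single linearized (EKF) measurement update. It uses a toy range-and-bearing measurement function and obtains the Jacobian $\\mathbf{H}_\\mathbf{h}$ with `torch.autograd.functional.jacobian` (available in recent PyTorch releases); the function `h` and all of the numbers are made up purely for illustration. In the Pyro example that follows, this linearization is handled for us under the hood."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch.autograd.functional import jacobian\n",
"\n",
"# toy nonlinear measurement: range and bearing of a 2-D position\n",
"def h(x):\n",
"    px, py = x[0], x[1]\n",
"    return torch.stack([torch.sqrt(px**2 + py**2), torch.atan2(py, px)])\n",
"\n",
"x_hat = torch.tensor([1.0, 1.0])  # predicted state (position only, for brevity)\n",
"P_hat = 0.1 * torch.eye(2)        # predicted covariance\n",
"R = 1e-3 * torch.eye(2)           # measurement noise covariance\n",
"z = torch.tensor([math.sqrt(2.0) + 0.02, math.pi / 4 - 0.01])  # noisy measurement\n",
"\n",
"H = jacobian(h, x_hat)              # measurement Jacobian H_h evaluated at x_hat\n",
"S = H @ P_hat @ H.T + R             # innovation covariance\n",
"K = P_hat @ H.T @ torch.inverse(S)  # Kalman gain\n",
"x = x_hat + K @ (z - h(x_hat))      # corrected state estimate\n",
"P = (torch.eye(2) - K @ H) @ P_hat  # corrected covariance\n",
"print(x)"
]
},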
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import math\n",
"\n",
"import torch\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"from pyro.infer.autoguide import AutoDelta\n",
"from pyro.optim import Adam\n",
"from pyro.infer import SVI, Trace_ELBO, config_enumerate\n",
"from pyro.contrib.tracking.extended_kalman_filter import EKFState\n",
"from pyro.contrib.tracking.distributions import EKFDistribution\n",
"from pyro.contrib.tracking.dynamic_models import NcvContinuous\n",
"from pyro.contrib.tracking.measurements import PositionMeasurement\n",
"\n",
"smoke_test = ('CI' in os.environ)\n",
"assert pyro.__version__.startswith('1.4.0')\n",
"pyro.enable_validation(True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dt = 1e-2\n",
"num_frames = 10\n",
"dim = 4\n",
"\n",
"# Continuous model\n",
"ncv = NcvContinuous(dim, 2.0)\n",
"\n",
"# Truth trajectory\n",
"xs_truth = torch.zeros(num_frames, dim)\n",
"# initial direction\n",
"theta0_truth = 0.0\n",
"# initial state\n",
"with torch.no_grad():\n",
" xs_truth[0, :] = torch.tensor([0.0, 0.0, math.cos(theta0_truth), math.sin(theta0_truth)])\n",
" for frame_num in range(1, num_frames):\n",
" # sample independent process noise\n",
" dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))\n",
" xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's specify the measurements. Notice that we only measure the positions of the particle."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Measurements\n",
"measurements = []\n",
"mean = torch.zeros(2)\n",
"# no correlations\n",
"cov = 1e-5 * torch.eye(2)\n",
"with torch.no_grad():\n",
" # sample independent measurement noise\n",
" dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))\n",
" # compute measurement means\n",
" zs = xs_truth[:, :2] + dzs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use a [Delta autoguide](http://docs.pyro.ai/en/dev/infer.autoguide.html#autodelta) to learn MAP estimates of the position and measurement covariances. The `EKFDistribution` computes the joint log density of all of the EKF states given a tensor of sequential measurements."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model(data):\n",
" # a HalfNormal can be used here as well\n",
" R = pyro.sample('pv_cov', dist.HalfCauchy(2e-6)) * torch.eye(4)\n",
" Q = pyro.sample('measurement_cov', dist.HalfCauchy(1e-6)) * torch.eye(2)\n",
" # observe the measurements\n",
" pyro.sample('track_{}'.format(i), EKFDistribution(xs_truth[0], R, ncv,\n",
" Q, time_steps=num_frames),\n",
" obs=data)\n",
" \n",
"guide = AutoDelta(model) # MAP estimation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optim = pyro.optim.Adam({'lr': 2e-2})\n",
"svi = SVI(model, guide, optim, loss=Trace_ELBO(retain_graph=True))\n",
"\n",
"pyro.set_rng_seed(0)\n",
"pyro.clear_param_store()\n",
"\n",
"for i in range(250 if not smoke_test else 2):\n",
" loss = svi.step(zs)\n",
" if not i % 10:\n",
" print('loss: ', loss)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# retrieve states for visualization\n",
"R = guide()['pv_cov'] * torch.eye(4)\n",
"Q = guide()['measurement_cov'] * torch.eye(2)\n",
"ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)\n",
"states= ekf_dist.filter_states(zs)"
]
},
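{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pre-rendered figure below compares the true track with the EKF estimate. A rough sketch of how a similar plot could be produced from the objects above is shown here; it assumes each returned `EKFState` exposes a `mean` tensor whose first two entries are the position estimate, so treat it as illustrative rather than definitive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# truth trajectory and the noisy position measurements\n",
"plt.plot(xs_truth[:, 0].numpy(), xs_truth[:, 1].numpy(), label='truth')\n",
"plt.scatter(zs[:, 0].numpy(), zs[:, 1].numpy(), marker='x', label='measurements')\n",
"\n",
"# filtered position estimates (assumes each EKFState exposes a `mean` tensor)\n",
"means = torch.stack([s.mean[:2].detach() for s in states])\n",
"plt.plot(means[:, 0].numpy(), means[:, 1].numpy(), label='EKF estimate')\n",
"\n",
"plt.legend()\n",
"plt.show()"
]
},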
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/html"
},
"source": [
"<center><figure>\n",
" <table>\n",
" <tr>\n",
" <td> \n",
" <img src=\"_static/img/ekf_track.png\" style=\"width: 900px;\">\n",
" </td>\n",
" </tr>\n",
" </table> <center>\n",
" <figcaption> \n",
" <font size=\"+1\"><b>Figure 1:</b>True track and EKF prediction with error. </font> \n",
" </figcaption> </center>\n",
"</figure></center>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"\n",
"\\[1\\] Kalman, R. E. *A New Approach to Linear Filtering and Prediction Problems.* 1960\n",
"\n",
"\\[2\\] Welch, Greg, and Bishop, Gary. *An Introduction to the Kalman Filter.* 2006.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}