in pyro/contrib/gp/models/sgpr.py [0:0]
def forward(self, Xnew, full_cov=False, noiseless=True):
r"""
Computes the mean and covariance matrix (or variance) of the Gaussian Process
posterior over the test inputs :math:`X_{new}`:

.. math:: p(f^* \mid X_{new}, X, y, k, X_u, \epsilon) = \mathcal{N}(loc, cov).

.. note:: The noise parameter ``noise`` (:math:`\epsilon`), the inducing-point
    parameter ``Xu``, together with the kernel's parameters, have been learned
    from a training procedure (MCMC or SVI).

:param torch.Tensor Xnew: A test input tensor. Note that
    ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
:param bool full_cov: A flag to decide whether to predict the full covariance
    matrix or just the variance.
:param bool noiseless: A flag to decide whether to include observation noise in
    the prediction output.
:returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
self._check_Xnew_shape(Xnew)
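# Use "guide" mode so that learned parameter values (rather than their priors) are
# used for prediction.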
self.set_mode("guide")
# W = inv(Luu) @ Kuf
# Ws = inv(Luu) @ Kus
# D as in self.model()
# K = I + W @ inv(D) @ W.T = L @ L.T
# S = inv[Kuu + Kuf @ inv(D) @ Kfu]
#   = inv(Luu).T @ inv[I + inv(Luu) @ Kuf @ inv(D) @ Kfu @ inv(Luu).T] @ inv(Luu)
#   = inv(Luu).T @ inv[I + W @ inv(D) @ W.T] @ inv(Luu)
#   = inv(Luu).T @ inv(K) @ inv(Luu)
#   = inv(Luu).T @ inv(L).T @ inv(L) @ inv(Luu)
# loc = Ksu @ S @ Kuf @ inv(D) @ y = Ws.T @ inv(L).T @ inv(L) @ W @ inv(D) @ y
# cov = Kss - Ksu @ inv(Kuu) @ Kus + Ksu @ S @ Kus
#     = Kss - Ws.T @ Ws + Ws.T @ inv(L).T @ inv(L) @ Ws
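# Interpretation: m_u = Kuu @ S @ Kuf @ inv(D) @ y and S_u = Kuu @ S @ Kuu are the
# posterior mean and covariance of the inducing outputs u; loc and cov follow from
# pushing q(u) = N(m_u, S_u) through p(f* | u).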
N = self.X.size(0)
M = self.Xu.size(0)
# TODO: cache these calculations to get faster inference
Kuu = self.kernel(self.Xu).contiguous()
Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal
Luu = Kuu.cholesky()
Kuf = self.kernel(self.Xu, self.X)
W = Kuf.triangular_solve(Luu, upper=False)[0]
D = self.noise.expand(N)
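# For the VFE/DTC approximations, D is just the observation noise; FITC adds the
# diagonal correction diag(Kff - Qff) below so that the marginal prior variances
# stay exact.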
if self.approx == "FITC":
Kffdiag = self.kernel(self.X, diag=True)
Qffdiag = W.pow(2).sum(dim=0)
D = D + Kffdiag - Qffdiag
W_Dinv = W / D
K = W_Dinv.matmul(W.t()).contiguous()
K.view(-1)[::M + 1] += 1 # add identity matrix to K
L = K.cholesky()
# get y_residual and convert it into 2D tensor for packing
y_residual = self.y - self.mean_function(self.X)
y_2D = y_residual.reshape(-1, N).t()
W_Dinv_y = W_Dinv.matmul(y_2D)
# End caching ----------
Kus = self.kernel(self.Xu, Xnew)
Ws = Kus.triangular_solve(Luu, upper=False)[0]
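# Pack both right-hand sides so that a single triangular solve against L serves the
# mean term (W_Dinv_y) and the covariance term (Ws).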
pack = torch.cat((W_Dinv_y, Ws), dim=1)
Linv_pack = pack.triangular_solve(L, upper=False)[0]
# unpack
Linv_W_Dinv_y = Linv_pack[:, :W_Dinv_y.shape[1]]
Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:]
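# C is the number of test inputs; loc (and cov below) keep any leading batch
# dimensions of y.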
C = Xnew.size(0)
loc_shape = self.y.shape[:-1] + (C,)
loc = Linv_W_Dinv_y.t().matmul(Linv_Ws).reshape(loc_shape)
if full_cov:
Kss = self.kernel(Xnew).contiguous()
if not noiseless:
Kss.view(-1)[::C + 1] += self.noise # add noise to the diagonal
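# Qss = Ksu @ inv(Kuu) @ Kus is the Nystrom approximation of Kss; the last term
# adds back Ksu @ S @ Kus (see the cov derivation above).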
Qss = Ws.t().matmul(Ws)
cov = Kss - Qss + Linv_Ws.t().matmul(Linv_Ws)
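# cov does not depend on y, so it is simply broadcast across y's batch dimensions.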
cov_shape = self.y.shape[:-1] + (C, C)
cov = cov.expand(cov_shape)
else:
Kssdiag = self.kernel(Xnew, diag=True)
if not noiseless:
Kssdiag = Kssdiag + self.noise
Qssdiag = Ws.pow(2).sum(dim=0)
cov = Kssdiag - Qssdiag + Linv_Ws.pow(2).sum(dim=0)
cov_shape = self.y.shape[:-1] + (C,)
cov = cov.expand(cov_shape)
return loc + self.mean_function(Xnew), cov
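
# A minimal usage sketch (not part of this file). The tensors, the RBF kernel, and
# the constructor arguments below are assumptions for illustration only; training
# with SVI or MCMC is elided.
#
#   import torch
#   import pyro.contrib.gp as gp
#
#   X = torch.randn(100, 2)                  # training inputs
#   y = torch.randn(100)                     # training targets
#   Xu = X[:20].clone()                      # inducing points
#   kernel = gp.kernels.RBF(input_dim=2)
#   sgpr = gp.models.SparseGPRegression(X, y, kernel, Xu, approx="FITC")
#   # ... optimize sgpr's parameters with SVI or sample them with MCMC ...
#   Xnew = torch.randn(10, 2)
#   loc, var = sgpr(Xnew, full_cov=False, noiseless=True)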