in botorch/models/pairwise_gp.py [0:0]
def forward(self, datapoints: Tensor) -> MultivariateNormal:
    r"""Calculate a posterior or prior prediction.

    During training mode, forward is implemented solely for gradient-based
    hyperparameter optimization: it re-calculates the utility f using its
    analytical form at f_map so that gradients of the hyperparameters can
    be obtained.

    Args:
        datapoints: A `batch_shape x n x d` Tensor,
            should be the same as self.datapoints during training.

    Returns:
        A MultivariateNormal object, being one of the following:

        1. Posterior centered at MAP points for training data (training mode)
        2. Prior predictions (prior mode)
        3. Predictive posterior (eval mode)
    """
    # Training mode: optimizing
    if self.training:
        if self._has_no_data():
            raise RuntimeError(
                "datapoints and comparisons cannot be None in training mode. "
                "Call .eval() for prior predictions, "
                "or call .set_train_data() to add training data."
            )
        if datapoints is not self.datapoints:
            raise RuntimeError("Must train on training data")
        transformed_dp = self.transform_inputs(datapoints)
        # We pass the untransformed datapoints into set_train_data as we will
        # be setting self.datapoints as the untransformed datapoints;
        # self.transform_inputs will be called inside before calling _update().
        self.set_train_data(datapoints, self.comparisons, update_model=True)
        # Take a Newton step on the posterior MAP point to fill
        # in gradients for pytorch.
        self.utility = self._util_newton_updates(
            transformed_dp, self.utility, max_iter=1
        )
        hl = self.likelihood_hess = self._hess_likelihood_f_sum(
            self.utility, self.D, self.DT
        )
        covar = self.covar
        # Apply matrix inversion lemma on eq. in page 27 of [Brochu2010tutorial]_
        # (A + B)^-1 = A^-1 - A^-1 @ (I + B A^-1)^-1 @ B A^-1
        # where A = covar_inv, B = hl
        hl_cov = hl @ covar
        eye = torch.eye(
            hl_cov.size(-1),
            dtype=self.datapoints.dtype,
            device=self.datapoints.device,
        ).expand(hl_cov.shape)
        hl_cov_I = hl_cov + eye  # add I to hl_cov
        train_covar_map = covar - covar @ torch.linalg.solve(hl_cov_I, hl_cov)
        output_mean, output_covar = self.utility, train_covar_map
    # Prior mode
    elif settings.prior_mode.on() or self._has_no_data():
        transformed_new_dp = self.transform_inputs(datapoints)
        # If we don't have any data yet, use the prior GP to make predictions.
        output_mean, output_covar = self._prior_predict(transformed_new_dp)
    # Posterior mode
    else:
        transformed_dp = self.transform_inputs(self.datapoints)
        transformed_new_dp = self.transform_inputs(datapoints).to(transformed_dp)
        # self.utility might be None if an exception was raised and _update
        # failed to be called during hyperparameter optimization
        # procedures (e.g., fit_gpytorch_scipy).
        if self.utility is None:
            self._update(transformed_dp)
        if self.pred_cov_fac_need_update:
            self._update_utility_derived_values()
        X, X_new = self._transform_batch_shape(transformed_dp, transformed_new_dp)
        covar_chol, _ = self._transform_batch_shape(self.covar_chol, X_new)
        hl, _ = self._transform_batch_shape(self.likelihood_hess, X_new)
        hlcov_eye, _ = self._transform_batch_shape(self.hlcov_eye, X_new)
        # Compute predictive mean and covariance.
        covar_xnew_x = self._calc_covar(X_new, X)
        covar_x_xnew = covar_xnew_x.transpose(-1, -2)
        covar_xnew = self._calc_covar(X_new, X_new)
        p = self.utility - self._prior_mean(X)
        covar_inv_p = torch.cholesky_solve(p.unsqueeze(-1), covar_chol)
        pred_mean = (covar_xnew_x @ covar_inv_p).squeeze(-1)
        pred_mean = pred_mean + self._prior_mean(X_new)
        # [Brochu2010tutorial]_ page 27
        # Predictive covariance factor: hlcov_eye = (K + C^-1)
        # fac = (K + C^-1)^-1 @ k = pred_cov_fac_inv @ covar_x_xnew;
        # the substitution method is used here to calculate fac.
        fac = torch.linalg.solve(hlcov_eye, hl @ covar_x_xnew)
        pred_covar = covar_xnew - (covar_xnew_x @ fac)
        output_mean, output_covar = pred_mean, pred_covar

    try:
        # Build the jitter matrix from output_covar's own dtype/device.
        # Previously, when self.datapoints was None (prior mode with no data),
        # the fallback torch.eye(...) used the default dtype/device, which
        # crashed when output_covar lived on a non-CPU device or used a
        # non-default dtype. output_covar always carries the correct
        # dtype/device for the addition below, so deriving from it is safe in
        # every branch.
        diag_jitter = torch.eye(
            output_covar.size(-1),
            dtype=output_covar.dtype,
            device=output_covar.device,
        )
        diag_jitter = diag_jitter.expand(output_covar.shape)
        diag_jitter = diag_jitter * self._jitter
        # Preemptively adding jitter to the diagonal to prevent the use of
        # _add_jitter, given that torch.cholesky may be very slow on non-pd
        # matrix input.
        # See https://github.com/pytorch/pytorch/issues/34272
        # TODO: remove this once the torch.cholesky issue is resolved
        output_covar = output_covar + diag_jitter
        post = MultivariateNormal(output_mean, output_covar)
    except RuntimeError:
        # Fallback: let _add_jitter iteratively stabilize the covariance.
        output_covar = self._add_jitter(output_covar)
        post = MultivariateNormal(output_mean, output_covar)
    return post