in botorch/models/pairwise_gp.py [0:0]
def _update(self, datapoints: Tensor, **kwargs) -> None:
    r"""Update the covariance matrix and the MAP estimate of the utility values.

    The update proceeds in two steps:

    1. Re-evaluate the covariance matrix, since the data or the
       hyperparameters may have changed.
    2. Approximate the maximum a posteriori (MAP) utility values ``f`` by
       finding a root of ``self._grad_posterior_f`` with
       ``scipy.optimize.fsolve``, using ``self._hess_posterior_f`` as the
       Jacobian.

    Should be called after the data or the hyperparameters change, so that
    ``self.utility`` (f_map) and the related quantities are refreshed.

    ``self._xtol`` and ``self._maxfev`` are forwarded to ``fsolve`` as
    ``xtol`` and ``maxfev`` to control its stopping criteria. When ``None``,
    they default to ``1e-6`` and ``100`` respectively.

    Args:
        datapoints: (Transformed) datapoints used for finding f_map.
        **kwargs: Additional keyword arguments forwarded verbatim to
            ``scipy.optimize.fsolve``.
    """
    xtol = 1e-6 if self._xtol is None else self._xtol
    maxfev = 100 if self._maxfev is None else self._maxfev

    def _solve(x_init, fsolve_args):
        # Single fsolve invocation shared by the batched and non-batched
        # paths. RuntimeWarnings emitted during the solve are silenced and the
        # returned estimate is used as is.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            return optimize.fsolve(
                x0=x_init,
                func=self._grad_posterior_f,
                fprime=self._hess_posterior_f,
                xtol=xtol,
                maxfev=maxfev,
                args=fsolve_args,
                **kwargs,
            )

    # Use the latest hyperparameters for the covariance before computing f_map.
    self._update_covar(datapoints)

    # scipy root finding runs on NumPy arrays; no autograd graph is built here.
    with torch.no_grad():
        # Warm start from the previous solution when its shape still matches,
        # otherwise start from a random point.
        init_x0_size = self.batch_shape + torch.Size([self.n])
        if self._x0 is None or torch.Size(self._x0.shape) != init_x0_size:
            x0 = np.random.rand(*init_x0_size)
        else:
            x0 = self._x0

        if len(self.batch_shape) > 0:
            # Batch mode: flatten the batch dimensions and run fsolve
            # sequentially on CPU for each batch element.
            # TODO: enable vectorization/parallelization here
            x0 = x0.reshape(-1, self.n)
            dp_v = datapoints.view(-1, self.n, self.dim).cpu()
            D_v = self.D.view(-1, self.m, self.n).cpu()
            DT_v = self.DT.view(-1, self.n, self.m).cpu()
            ch_v = self.covar_chol.view(-1, self.n, self.n).cpu()
            ci_v = self.covar_inv.view(-1, self.n, self.n).cpu()
            x = np.empty(x0.shape)
            for i, x0_i in enumerate(x0):
                x[i] = _solve(
                    x0_i, (dp_v[i], D_v[i], DT_v[i], ch_v[i], ci_v[i], True)
                )
            x = x.reshape(*init_x0_size)
        else:
            # fsolve only works on CPU
            fsolve_args = (
                datapoints.cpu(),
                self.D.cpu(),
                self.DT.cpu(),
                self.covar_chol.cpu(),
                self.covar_inv.cpu(),
                True,
            )
            x = _solve(x0, fsolve_args)

        self._x0 = x.copy()  # save for warm-starting
        f = torch.tensor(x, dtype=datapoints.dtype, device=datapoints.device)

    # To perform hyperparameter optimization, this needs to be recalculated
    # when calling forward() in order to obtain correct gradients.
    # self.likelihood_hess is updated here for the rare case where we
    # do not want to call forward().
    self.likelihood_hess = self._hess_likelihood_f_sum(f, self.D, self.DT)

    # Lazy update of hlcov_eye, which is used in calculating the posterior
    # during training.
    self.pred_cov_fac_need_update = True
    # Fill in dummy values for hlcov_eye so that load_state_dict can function.
    # Allocate the placeholder with the same dtype/device as the utility
    # values, rather than torch's default dtype on CPU, so it stays consistent
    # with the other model tensors.
    hlcov_eye_size = torch.Size((*self.likelihood_hess.shape[:-2], self.n, self.n))
    self.hlcov_eye = torch.empty(hlcov_eye_size, dtype=f.dtype, device=f.device)

    self.utility = f.clone().requires_grad_(True)