def forward()

in botorch/models/pairwise_gp.py


    def forward(self, datapoints: Tensor) -> MultivariateNormal:
        r"""Calculate a posterior or prior prediction.

        In training mode, forward is implemented solely for gradient-based
        hyperparameter optimization: it re-calculates the utility f using its
        analytical form at f_map so that gradients with respect to the
        hyperparameters can be obtained.

        Args:
            datapoints: A `batch_shape x n x d` Tensor; in training mode this
                must be the same tensor as self.datapoints.

        Returns:
            A MultivariateNormal object, one of the following:
                1. Posterior centered at MAP points for training data (training mode)
                2. Prior predictions (prior mode)
                3. Predictive posterior (eval mode)
        """

        # Training mode: optimizing
        if self.training:
            if self._has_no_data():
                raise RuntimeError(
                    "datapoints and comparisons cannot be None in training mode. "
                    "Call .eval() for prior predictions, "
                    "or call .set_train_data() to add training data."
                )

            if datapoints is not self.datapoints:
                raise RuntimeError("Must train on training data")

            transformed_dp = self.transform_inputs(datapoints)

            # Pass the untransformed datapoints into set_train_data, since
            # self.datapoints is stored untransformed; self.transform_inputs
            # is called internally before _update() runs.
            self.set_train_data(datapoints, self.comparisons, update_model=True)

            # Take one Newton step from the posterior MAP point so that the
            # result is an analytical function of the hyperparameters,
            # giving pytorch gradients to backpropagate through
            self.utility = self._util_newton_updates(
                transformed_dp, self.utility, max_iter=1
            )

            hl = self.likelihood_hess = self._hess_likelihood_f_sum(
                self.utility, self.D, self.DT
            )
            covar = self.covar
            # Apply the matrix inversion lemma to the eq. on page 27 of
            # [Brochu2010tutorial]_
            # (A + B)^-1 = A^-1 - A^-1 @ (I + BA^-1)^-1 @ BA^-1
            # where A = covar_inv, B = hl
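            # With A^-1 = covar, the lemma gives
            #   (covar^-1 + hl)^-1 = covar - covar @ (I + hl @ covar)^-1 @ (hl @ covar)
            # which is evaluated below via a linear solve, so no explicit
            # matrix inverse is ever formed.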
            hl_cov = hl @ covar
            eye = torch.eye(
                hl_cov.size(-1),
                dtype=self.datapoints.dtype,
                device=self.datapoints.device,
            ).expand(hl_cov.shape)
            hl_cov_I = hl_cov + eye  # add I to hl_cov
            train_covar_map = covar - covar @ torch.linalg.solve(hl_cov_I, hl_cov)
            output_mean, output_covar = self.utility, train_covar_map

        # Prior mode
        elif settings.prior_mode.on() or self._has_no_data():
            transformed_new_dp = self.transform_inputs(datapoints)
            # use the prior GP for predictions (either prior mode is on, or
            # the model has no training data yet)
            output_mean, output_covar = self._prior_predict(transformed_new_dp)

        # Posterior mode
        else:
            transformed_dp = self.transform_inputs(self.datapoints)
            transformed_new_dp = self.transform_inputs(datapoints).to(transformed_dp)

            # self.utility might be None if an exception was raised and
            # _update failed to be called during hyperparameter optimization
            # procedures (e.g., fit_gpytorch_scipy)
            if self.utility is None:
                self._update(transformed_dp)

            if self.pred_cov_fac_need_update:
                self._update_utility_derived_values()

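            # Expand the cached training-time tensors and the new inputs to
            # matching batch shapes so the matmuls below broadcast correctly.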
            X, X_new = self._transform_batch_shape(transformed_dp, transformed_new_dp)
            covar_chol, _ = self._transform_batch_shape(self.covar_chol, X_new)
            hl, _ = self._transform_batch_shape(self.likelihood_hess, X_new)
            hlcov_eye, _ = self._transform_batch_shape(self.hlcov_eye, X_new)

            # Compute the predictive mean and covariance at the new points
            covar_xnew_x = self._calc_covar(X_new, X)
            covar_x_xnew = covar_xnew_x.transpose(-1, -2)
            covar_xnew = self._calc_covar(X_new, X_new)
            p = self.utility - self._prior_mean(X)

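            # Predictive mean via the standard GP conditional,
            #   mean(X_new) = m(X_new) + K(X_new, X) @ K(X, X)^-1 @ (f_map - m(X)),
            # with the MAP utility standing in for observed values;
            # cholesky_solve applies K(X, X)^-1 using the cached factor covar_chol.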
            covar_inv_p = torch.cholesky_solve(p.unsqueeze(-1), covar_chol)
            pred_mean = (covar_xnew_x @ covar_inv_p).squeeze(-1)
            pred_mean = pred_mean + self._prior_mean(X_new)

            # [Brochu2010tutorial]_ page 27
            # Predictive covariance factor: hlcov_eye = hl @ K + I
            # fac = (K + C^-1)^-1 @ k = pred_cov_fac_inv @ covar_x_xnew, with C = hl
            # fac is computed with a linear solve rather than an explicit inverse
            fac = torch.linalg.solve(hlcov_eye, hl @ covar_x_xnew)
            pred_covar = covar_xnew - (covar_xnew_x @ fac)

            output_mean, output_covar = pred_mean, pred_covar

        try:
            # output_covar already carries the correct dtype and device
            # (avoids a mismatch when self.datapoints is None)
            diag_jitter = torch.eye(
                output_covar.size(-1),
                dtype=output_covar.dtype,
                device=output_covar.device,
            )
            diag_jitter = diag_jitter.expand(output_covar.shape)
            diag_jitter = diag_jitter * self._jitter
            # Preemptively adding jitter to diagonal to prevent the use of _add_jitter
            # given that torch.cholesky may be very slow on non-pd matrix input
            # See https://github.com/pytorch/pytorch/issues/34272
            # TODO: remove this once torch.cholesky issue is resolved
            output_covar = output_covar + diag_jitter
            post = MultivariateNormal(output_mean, output_covar)
        except RuntimeError:
            output_covar = self._add_jitter(output_covar)
            post = MultivariateNormal(output_mean, output_covar)

        return post
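
Usage sketch (not from the source): a minimal example of exercising the
training and eval branches of forward() through PairwiseGP. The toy
datapoints and comparisons below are made up for illustration; PairwiseGP
and its constructor are the actual BoTorch API.

    import torch
    from botorch.models.pairwise_gp import PairwiseGP

    # n=6 candidate points in d=2; each comparison row (i, j) records that
    # datapoints[i] was preferred over datapoints[j]
    datapoints = torch.rand(6, 2, dtype=torch.float64)
    comparisons = torch.tensor([[0, 1], [2, 3], [4, 5]])
    model = PairwiseGP(datapoints, comparisons)

    # Training mode: pass the stored training tensor itself, since forward
    # checks identity with self.datapoints; returns the posterior centered
    # at the MAP utility
    model.train()
    mvn_train = model(model.datapoints)

    # Eval mode: predictive posterior at new points
    model.eval()
    test_X = torch.rand(4, 2, dtype=torch.float64)
    mvn_pred = model(test_X)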