in botorch/acquisition/multi_objective/analytic.py [0:0]
def forward(self, X: Tensor) -> Tensor:
    """Evaluate the acquisition value at ``X``.

    The model posterior at ``X`` is passed through ``self.objective``. Its
    mean and floored standard deviation are combined with the stored cell
    bounds via ``self.psi`` and ``self.nu``. For every cell, each outcome
    contributes either its ``psi(l, l) - psi(l, u)`` factor or its ``nu``
    factor; the products of all such selections are summed over the
    selections and then over the cells.

    Args:
        X: Candidate points handed to ``self.model.posterior``. Its dtype
            picks the finite cap applied to the upper cell bounds.
            NOTE(review): the expected shape of ``X`` is not visible in
            this method -- confirm against the class documentation.

    Returns:
        The summed products, i.e. the gathered factors with their last
        three dimensions (outcome, selection, cell) reduced away.
    """
    transformed = self.objective(self.model.posterior(X))
    mean = transformed.mean
    # Floor the variance before taking the square root.
    std = transformed.variance.clamp_min(1e-9).sqrt()

    # The upper bounds can hold `inf` entries, which are not differentiable,
    # so cap them at a large finite value chosen by the dtype of X.
    if X.dtype == torch.double:
        finite_cap = 1e10
    else:
        finite_cap = 1e8
    upper = self.cell_upper_bounds.clamp_max(finite_cap)
    lower = self.cell_lower_bounds
    moments = {"mu": mean, "sigma": std}

    # psi evaluated on (lower, upper) and on the degenerate pair (lower, lower),
    # followed by nu on (lower, upper).
    psi_lower_upper = self.psi(lower=lower, upper=upper, **moments)
    psi_lower_lower = self.psi(lower=lower, upper=lower, **moments)
    nu_lower_upper = self.nu(lower=lower, upper=upper, **moments)

    # Put the two candidate factors next to each other on a new
    # second-to-last axis: index 0 is psi(l, l) - psi(l, u), index 1 is nu.
    factor_pair = torch.stack(
        (psi_lower_lower - psi_lower_upper, nu_lower_upper), dim=-2
    )

    # Every row of `_cross_product_indices` chooses, per outcome, which of the
    # two factors to use. E.g. with two outcomes the selections are
    #   [psi_diff_0, psi_diff_1], [nu_0, psi_diff_1],
    #   [psi_diff_0, nu_1],       [nu_0, nu_1]
    # (presumably 2^m rows for m outcomes -- the index tensor is built
    # outside this method). The index is broadcast over the leading
    # dimensions of `factor_pair` before gathering.
    selector = self._cross_product_indices
    gather_index = selector.expand(factor_pair.shape[:-2] + selector.shape)
    selected = torch.gather(factor_pair, -2, gather_index)

    # Multiply across outcomes, then add up the selections and finally the
    # cells.
    per_selection = selected.prod(dim=-1)
    per_cell = per_selection.sum(dim=-1)
    return per_cell.sum(dim=-1)