in botorch/acquisition/multi_objective/monte_carlo.py [0:0]
def _set_cell_bounds(self, num_new_points: int) -> None:
    r"""Build a box decomposition of the baseline objectives for every sample.

    If ``X_baseline`` contains points, the sampler is reset for the posterior
    over ``X_baseline``. The point count passed to it is scaled by the number
    of outputs produced per baseline point, to account for a possible
    one-to-many transform. A detached copy of the sampler's base samples is
    then registered on ``self.base_sampler``, and objective values are
    computed from the samples it draws. When constraints are present, a
    feasibility mask is computed as well. If ``X_baseline`` is empty, an
    objective tensor with zero points is used instead.

    All sample batch dimensions are then flattened into one. Initial
    hypervolumes are computed when no partitioning exists yet and incremental
    NEHVI is disabled. A partitioning is built: one per sample on the CPU for
    more than two objectives, otherwise a single batched one. Its hypercell
    bounds, reshaped back to the sample batch shape, are registered as the
    ``cell_lower_bounds`` and ``cell_upper_bounds`` buffers.

    Args:
        num_new_points: How many points beyond those in ``X_baseline`` were
            used by the previous box decomposition. For the very first
            decomposition, pass the number of points in ``X_baseline``.
    """
    feas = None
    if not self.X_baseline.shape[0] > 0:
        # No baseline points: use an objective tensor holding zero points.
        obj = torch.empty(
            *self.sampler._sample_shape,
            0,
            self.ref_point.shape[-1],
            dtype=self.ref_point.dtype,
            device=self.ref_point.device,
        )
    else:
        with torch.no_grad():
            posterior = self.model.posterior(self.X_baseline)
        # Reset the sampler. A one-to-many transform may produce several
        # outputs per baseline point, so scale the point count accordingly.
        self.q_in = -1
        outputs_per_point = posterior.event_shape[-2] // self.X_baseline.shape[-2]
        self._set_sampler(
            q_in=num_new_points * outputs_per_point, posterior=posterior
        )
        # Give base_sampler its own detached copy of the sampler's base samples.
        frozen_base_samples = self.sampler.base_samples.detach().clone()
        self.base_sampler.register_buffer("base_samples", frozen_base_samples)
        samples = self.base_sampler(posterior)
        if self._cache_root:
            # cache posterior
            self._cache_root_decomposition(posterior=posterior)
        obj = self.objective(samples, X=self.X_baseline)
        if self.constraints is not None:
            # A point is feasible only if every constraint value is <= 0.
            satisfied = [c(samples) <= 0 for c in self.constraints]
            feas = torch.stack(satisfied, dim=0).all(dim=0)

    self._batch_sample_shape = obj.shape[:-2]
    # Collapse all sample batch dims into one. numel() is used instead of
    # view(-1, ...) so that the zero-baseline-point case still works.
    flat_batch = self._batch_sample_shape.numel()
    obj = obj.view(flat_batch, *obj.shape[-2:])
    if self.constraints is not None and feas is not None:
        feas = feas.view(flat_batch, *feas.shape[-1:])

    if self.partitioning is None and not self.incremental_nehvi:
        self._compute_initial_hvs(obj=obj, feas=feas)

    if self.ref_point.shape[-1] <= 2:
        # Two or fewer objectives: one batched partitioning over all samples.
        batched_ref_point = self.ref_point.unsqueeze(0).expand(
            obj.shape[0], self.ref_point.shape[-1]
        )
        padded_obj = _pad_batch_pareto_frontier(
            Y=obj,
            ref_point=batched_ref_point,
            feasibility_mask=feas,
        )
        self.partitioning = self.p_class(
            ref_point=self.ref_point, Y=padded_obj, **self.p_kwargs
        )
    else:
        # More objectives: build one partitioning per sample on the CPU,
        # where these algorithms run faster due to their advanced indexing.
        ref_point_cpu = self.ref_point.cpu()
        per_sample_obj = obj.cpu()
        if self.constraints is not None and feas is not None:
            # Keep only the feasible points of each sample.
            per_sample_obj = [
                y[mask] for y, mask in zip(per_sample_obj, feas.cpu())
            ]
        self.partitioning = BoxDecompositionList(
            *(
                self.p_class(ref_point=ref_point_cpu, Y=y, **self.p_kwargs)
                for y in per_sample_obj
            )
        )

    bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point)
    # Leading dim 2 holds (lower, upper); restore the original sample batch dims.
    bounds = bounds.view(2, *self._batch_sample_shape, *bounds.shape[-2:])
    lower_bounds, upper_bounds = bounds[0], bounds[1]
    self.register_buffer("cell_lower_bounds", lower_bounds)
    self.register_buffer("cell_upper_bounds", upper_bounds)