in botorch/utils/multi_objective/box_decompositions/non_dominated.py [0:0]
def _get_hypercell_bounds(self, aug_pareto_Y: Tensor) -> Tensor:
    r"""Get the bounds of each hypercell in the decomposition.

    ``self.hypercells`` is a ``2 x (batch_shape) x num_cells x m``-dim tensor
    of integer indices into the point dimension of ``aug_pareto_Y``. Entry
    ``[k, *b, c, j]`` selects the point whose ``j``-th outcome is the ``j``-th
    coordinate of the lower (``k = 0``) or upper (``k = 1``) vertex of cell
    ``c`` in batch ``b``.

    Args:
        aug_pareto_Y: A `(batch_shape) x (n_pareto + 2) x m`-dim tensor
            containing the augmented Pareto front.

    Returns:
        A `2 x (batch_shape) x num_cells x m`-dim tensor containing the
        lower and upper vertices bounding each hypercell, i.e.
        ``out[k, *b, c, j] = aug_pareto_Y[*b, self.hypercells[k, *b, c, j], j]``.
        Any number of batch dimensions (including none) is supported, and an
        empty decomposition (``num_cells == 0``) yields an empty tensor.
    """
    # Prepend a size-2 lower/upper dimension to the front with a zero-copy
    # ``expand`` so it has the same rank as ``self.hypercells``. Then gather
    # along the point dimension (-2). ``gather`` pairs every index with its own
    # batch position and outcome column, so it is equivalent to advanced
    # indexing with explicit (batch, point, outcome) index tensors. Unlike
    # that approach, it needs no flattening with ``view(..., -1)``, which
    # fails on 0-element tensors, and it works for any number of batch dims.
    return aug_pareto_Y.expand(2, *aug_pareto_Y.shape).gather(
        dim=-2, index=self.hypercells
    )