def _get_hypercell_bounds()

in botorch/utils/multi_objective/box_decompositions/non_dominated.py [0:0]


    def _get_hypercell_bounds(self, aug_pareto_Y: Tensor) -> Tensor:
        r"""Get the bounds of each hypercell in the decomposition.

        ``self.hypercells`` holds, for every cell and every outcome, the index
        (along dim -2 of ``aug_pareto_Y``) of the point whose coordinate in that
        outcome is the lower (``hypercells[0]``) or upper (``hypercells[1]``)
        bound of the cell. This method looks those coordinates up. Any number
        of batch dimensions is supported, as long as ``aug_pareto_Y`` and
        ``self.hypercells`` share the same batch shape.

        Args:
            aug_pareto_Y: A `(batch_shape) x (n_pareto + 2) x m`-dim tensor
                containing the augmented Pareto front.

        Returns:
            A `2 x (batch_shape) x num_cells x m`-dim tensor containing the
                lower and upper vertices bounding each hypercell.
        """
        # `hypercells` is `2 x (batch_shape) x num_cells x m` and stores point
        # indices. It is presumably int64 already, in which case `.long()` is a
        # no-op; `gather` requires an int64 index.
        cell_point_idx = self.hypercells.long()
        # Repeat the front over the leading lower/upper dim as a view (no copy)
        # so it matches `cell_point_idx` on every dim except -2.
        pareto_Y = aug_pareto_Y.unsqueeze(0).expand(2, *aug_pareto_Y.shape)
        # out[k, ..., c, j] = aug_pareto_Y[..., cell_point_idx[k, ..., c, j], j].
        # The result already has shape `2 x (batch_shape) x num_cells x m`.
        return pareto_Y.gather(dim=-2, index=cell_point_idx)