def _partition_space()

in botorch/utils/multi_objective/box_decompositions/non_dominated.py [0:0]


    def _partition_space(self) -> None:
        r"""Partition the non-dominated space into disjoint hypercells.

        This method supports an arbitrary number of outcomes, but is
        less efficient than `partition_space_2d` for the 2-outcome case.

        The space is split with a binary (divide-and-conquer) scheme over the
        indices of the augmented Pareto front, i.e. the Pareto front extended
        with an ideal point (per-outcome minimum minus one) and an anti-ideal
        point (per-outcome maximum plus one). Comparisons are done on
        `self._neg_pareto_Y`, where smaller values are better (the ideal point
        is built from the minimum).

        The result is stored in the `hypercells` buffer, a
        `2 x num_cells x m`-dim `long` tensor. For each cell, `hypercells[0]`
        and `hypercells[1]` hold, per outcome, the row index into the
        augmented Pareto front of the cell's lower and upper bound.

        A cell that still overlaps the Pareto front is discarded (neither
        stored nor split) when all of its index spans are adjacent, or when its
        fraction of the total volume is not above `self.alpha`.
        """
        # The binary partitioning algorithm works on indices into the augmented
        # Pareto front: an `(n_pareto + 2) x m` tensor.
        aug_pareto_Y_idcs = self._get_augmented_pareto_front_indices()
        device = self._neg_Y.device

        # A cell is a `2 x m` tensor of positions into `aug_pareto_Y_idcs`
        # (row 0: lower bound, row 1: upper bound, one column per outcome).
        # Start with a single cell spanning the entire augmented Pareto front.
        cell = torch.zeros(2, self.num_outcomes, dtype=torch.long, device=device)
        cell[1] = aug_pareto_Y_idcs.shape[0] - 1
        stack = [cell]

        # hypercells contains the indices of the (augmented) Pareto front
        # that specify the bounds of each hypercell.
        # It is a `2 x num_cells x m`-dim tensor. It stays empty unless at
        # least one cell is accepted below.
        self.register_buffer(
            "hypercells",
            torch.empty(2, 0, self.num_outcomes, dtype=torch.long, device=device),
        )
        outcome_idxr = torch.arange(
            self.num_outcomes, dtype=torch.long, device=device
        )

        # Accepted cells (each `2 x m`) are collected here and stacked once at
        # the end. Concatenating onto `self.hypercells` inside the loop would
        # re-copy every previously accepted cell on each acceptance, which is
        # quadratic in the number of cells.
        accepted_cells = []

        if self._neg_pareto_Y.shape[-2] == 0:
            # Edge case: empty Pareto set -> use a single cell (`2 x m`).
            accepted_cells.append(aug_pareto_Y_idcs[cell, outcome_idxr])
        else:
            # Extend the Pareto front with the ideal and anti-ideal point.
            ideal_point = self._neg_pareto_Y.min(dim=0, keepdim=True).values - 1
            anti_ideal_point = self._neg_pareto_Y.max(dim=0, keepdim=True).values + 1
            # `(n_pareto + 2) x m`
            aug_pareto_Y = torch.cat(
                [ideal_point, self._neg_pareto_Y, anti_ideal_point], dim=0
            )

            total_volume = (anti_ideal_point - ideal_point).prod()

            # Use binary partitioning until no cells are left to process.
            while stack:
                # The following 3 tensors are all `2 x m`
                cell = stack.pop()
                cell_bounds_pareto_idcs = aug_pareto_Y_idcs[cell, outcome_idxr]
                cell_bounds_pareto_values = aug_pareto_Y[
                    cell_bounds_pareto_idcs, outcome_idxr
                ]
                # Check the cell bounds against every Pareto point p:
                # - if, for every p, the upper bound is <= p in at least one
                #   outcome, no p dominates the cell's interior -> accept it.
                # - elif the same holds for the lower bound, the cell overlaps
                #   the Pareto front -> divide it along its longest edge
                #   (subject to the adjacency / volume checks below).
                # - otherwise some p is strictly below the lower bound in every
                #   outcome, so the whole cell is dominated -> drop it.
                if (
                    (cell_bounds_pareto_values[1] <= self._neg_pareto_Y)
                    .any(dim=1)
                    .all()
                ):
                    # Cell is entirely non-dominated
                    accepted_cells.append(cell_bounds_pareto_idcs)
                elif (
                    (cell_bounds_pareto_values[0] <= self._neg_pareto_Y)
                    .any(dim=1)
                    .all()
                ):
                    # The cell overlaps the Pareto front.
                    # Distance between the bounds in integer indices, shape `m`.
                    idx_dist = cell[1] - cell[0]

                    any_not_adjacent = (idx_dist > 1).any()
                    cell_volume = (
                        (cell_bounds_pareto_values[1] - cell_bounds_pareto_values[0])
                        .prod(dim=-1)
                        .item()
                    )

                    # Only divide a cell when it is not composed of adjacent
                    # indices and its fraction of the total volume is above the
                    # approximation threshold `self.alpha`; otherwise it is
                    # dropped.
                    if (
                        any_not_adjacent
                        and ((cell_volume / total_volume) > self.alpha).all()
                    ):
                        # Divide the cell over its largest dimension
                        # (largest by index length).
                        length, longest_dim = torch.max(idx_dist, dim=0)
                        length = length.item()
                        longest_dim = longest_dim.item()

                        new_length1 = int(round(length / 2.0))
                        new_length2 = length - new_length1

                        # Store divided cells; both share the split boundary:
                        # cell 1: subtract new_length1 from the upper bound
                        # cell 2: add new_length2 to the lower bound
                        for bound_idx, length_delta in (
                            (1, -new_length1),
                            (0, new_length2),
                        ):
                            new_cell = cell.clone()
                            new_cell[bound_idx, longest_dim] += length_delta
                            stack.append(new_cell)

        if accepted_cells:
            # `2 x num_cells x m`, same layout and order as appending one cell
            # at a time along dim 1.
            self.hypercells = torch.stack(accepted_cells, dim=1)