in botorch/utils/multi_objective/box_decompositions/non_dominated.py [0:0]
def _partition_space(self) -> None:
    r"""Partition the non-dominated space into disjoint hypercells.

    This method supports an arbitrary number of outcomes, but is
    less efficient than `partition_space_2d` for the 2-outcome case.

    The Pareto front (`self._neg_pareto_Y`) is extended with an ideal point
    (its per-outcome minimum minus one) and an anti-ideal point (its
    per-outcome maximum plus one). The box between them is then recursively
    bisected. Comparisons treat smaller values as better. Each popped cell is:

    - accepted, if no Pareto point is strictly better than the cell's upper
      corner in every outcome (the cell is entirely non-dominated);
    - dropped, if some Pareto point is strictly better than the cell's lower
      corner in every outcome (the whole cell is dominated);
    - otherwise (the cell overlaps the front) split along its longest edge,
      measured in Pareto-front indices. It is split only if some edge spans
      more than one index step and its share of the total volume exceeds
      `self.alpha`; if not, it is dropped.

    If the Pareto front is empty, a single cell spanning the augmented front
    is stored.

    The result is written to the `hypercells` buffer: a `2 x num_cells x m`
    long tensor of per-outcome indices into the augmented Pareto front. Row 0
    holds the lower bounds and row 1 the upper bounds of each cell.
    """
    device = self._neg_Y.device
    num_outcomes = self.num_outcomes
    # The binary partitioning algorithm uses indices of the augmented Pareto
    # front: `n_pareto + 2 x m`.
    aug_pareto_Y_idcs = self._get_augmented_pareto_front_indices()
    # A cell is a `2 x m` tensor of row positions in `aug_pareto_Y_idcs`
    # (row 0: lower bounds, row 1: upper bounds). Start with one cell that
    # spans the entire augmented Pareto front.
    cell = torch.zeros(2, num_outcomes, dtype=torch.long, device=device)
    cell[1] = aug_pareto_Y_idcs.shape[0] - 1
    stack = [cell]
    # `hypercells` contains the indices of the (augmented) Pareto front that
    # specify the bounds of each hypercell. It is a `2 x num_cells x m`-dim
    # tensor.
    self.register_buffer(
        "hypercells",
        torch.empty(2, 0, num_outcomes, dtype=torch.long, device=device),
    )
    outcome_idxr = torch.arange(num_outcomes, dtype=torch.long, device=device)
    # Accepted cells (each `2 x m`) are collected here and concatenated once at
    # the end. Growing `self.hypercells` with one `torch.cat` per accepted cell
    # would copy the entire buffer each time (quadratic in the cell count).
    accepted_cells = []

    if self._neg_pareto_Y.shape[-2] == 0:
        # Edge case: empty Pareto set -> use a single cell (`2 x m`).
        accepted_cells.append(aug_pareto_Y_idcs[cell, outcome_idxr])
    else:
        # Extend the Pareto front with the ideal and anti-ideal point.
        ideal_point = self._neg_pareto_Y.min(dim=0, keepdim=True).values - 1
        anti_ideal_point = self._neg_pareto_Y.max(dim=0, keepdim=True).values + 1
        # `n_pareto + 2 x m`
        aug_pareto_Y = torch.cat(
            [ideal_point, self._neg_pareto_Y, anti_ideal_point], dim=0
        )
        total_volume = (anti_ideal_point - ideal_point).prod()

        # Binary partitioning (depth-first via an explicit stack).
        while stack:
            cell = stack.pop()
            # `2 x m` indices into `aug_pareto_Y` bounding this cell, and the
            # corresponding `m`-dim lower / upper corner values.
            cell_bounds_pareto_idcs = aug_pareto_Y_idcs[cell, outcome_idxr]
            lower_values, upper_values = aug_pareto_Y[
                cell_bounds_pareto_idcs, outcome_idxr
            ]

            if (upper_values <= self._neg_pareto_Y).any(dim=1).all():
                # For every Pareto point, the upper corner is at least as good
                # in some outcome: the cell is entirely non-dominated.
                accepted_cells.append(cell_bounds_pareto_idcs)
                continue

            if not (lower_values <= self._neg_pareto_Y).any(dim=1).all():
                # Some Pareto point is strictly better than the lower corner
                # in every outcome: the whole cell is dominated; drop it.
                continue

            # The cell overlaps the Pareto front. Its extent in integer
            # indices has shape `m`.
            idx_dist = cell[1] - cell[0]
            cell_volume = (upper_values - lower_values).prod(dim=-1).item()
            # Only divide a cell when it is not composed of adjacent indices
            # and its fraction of the total volume is above the approximation
            # threshold `self.alpha`; otherwise the cell is dropped.
            if not (
                (idx_dist > 1).any()
                and ((cell_volume / total_volume) > self.alpha).all()
            ):
                continue

            # Divide the cell over its largest dimension (by index length).
            length, longest_dim = torch.max(idx_dist, dim=0)
            length = length.item()
            longest_dim = longest_dim.item()
            # NOTE: `round` rounds half to even, so for odd lengths either half
            # may be the larger one. Both halves share the split index, so
            # together they exactly cover the parent cell.
            new_length1 = int(round(length / 2.0))
            new_length2 = length - new_length1
            # Cell 1: subtract `new_length1` from the upper bound.
            first_half = cell.clone()
            first_half[1, longest_dim] -= new_length1
            # Cell 2: add `new_length2` to the lower bound.
            second_half = cell.clone()
            second_half[0, longest_dim] += new_length2
            # Push order matters for the (LIFO) processing order.
            stack.append(first_half)
            stack.append(second_half)

    if accepted_cells:
        self.hypercells = torch.cat(
            [self.hypercells] + [c.unsqueeze(1) for c in accepted_cells], dim=1
        )