def get_partition_bounds()

in botorch/utils/multi_objective/box_decompositions/utils.py [0:0]


def get_partition_bounds(Z: Tensor, U: Tensor, ref_point: Tensor) -> Tensor:
    r"""Get the cell bounds given the local upper bounds and the defining points.

    This implements Equation 2 in [Lacour17]_. In 0-indexed terms, for each
    local upper bound ``u = U[u_idx]`` the cell's vertices are:

    - lower[0] = Z[u_idx, 0, 0]
    - lower[j] = max_{k < j} Z[u_idx, k, j]   for j >= 1
    - upper[0] = ref_point[0]
    - upper[j] = U[u_idx, j]                  for j >= 1

    Cells whose upper vertex is not strictly greater than the lower vertex in
    every dimension are empty and are dropped.

    Args:
        Z: A `n x m x m`-dim tensor containing the defining points. The first
            dimension corresponds to u_idx, the second dimension corresponds to j,
            and Z[u_idx, j] is the set of defining points Z^j(u) where
            u = U[u_idx].
        U: A `n x m`-dim tensor containing the local upper bounds.
        ref_point: A `m`-dim tensor containing the reference point.

    Returns:
        A `2 x num_cells x m`-dim tensor (with `num_cells <= n`) containing the
            lower and upper vertices bounding each non-empty hypercell. The
            result has the dtype and device of `U`.
    """
    # Running maximum over the defining-point axis:
    # running_max[u, i, j] = max_{k <= i} Z[u, k, j].
    running_max = Z.cummax(dim=1).values
    # The lower bound in dimension j >= 1 is max_{k < j} Z[u, k, j], which is
    # running_max[u, j - 1, j], i.e. the first super-diagonal of each `m x m`
    # slice. This yields an `n x (m - 1)` tensor (empty when m == 1).
    lower_rest = torch.diagonal(running_max, offset=1, dim1=1, dim2=2)
    # z_1^1(u): the first coordinate comes directly from Z[u, 0, 0].
    lower = torch.cat([Z[:, :1, 0], lower_rest], dim=-1)
    # Upper vertex: U[u, j] for j >= 1, and z_1^r(u) = ref_point[0] for j == 0.
    upper = U.clone()
    upper[:, 0] = ref_point[0]
    # Cast to U's dtype/device so the output matches U regardless of Z's dtype.
    bounds = torch.stack([lower.to(U), upper], dim=0)
    # remove empty partitions
    # Note: the comparison will evaluate as True if the lower and upper bound
    # are both (-inf), which could happen if the reference point is -inf.
    empty = (bounds[1] <= bounds[0]).any(dim=-1)
    return bounds[:, ~empty]