in botorch/acquisition/multi_objective/multi_output_risk_measures.py [0:0]
def get_mvar_set_gpu(self, Y: Tensor) -> "list[Tensor]":
    r"""Find MVaR set based on the definition in [Prekopa2012MVaR]_.

    NOTE: This is much faster on GPU than the alternative but it scales very
    poorly on CPU as `n_w` increases. This should be preferred if a GPU is
    available or when `n_w <= 64`. In addition, this supports `m >= 2` outcomes
    (vs `m = 2` for the CPU version) and it should be used if `m > 2`.

    This first calculates the CDF for each point on the extended domain of the
    random variable (the grid defined by the given samples), then takes the
    values with CDF equal to (rounded if necessary) `alpha`. The non-dominated
    subset of these form the MVaR set.

    Args:
        Y: A `batch x n_w x m`-dim tensor of observations. A 2-dim `n_w x m`
            tensor is treated as a single batch.

    Returns:
        A `batch` length list of `k x m`-dim tensors of MVaR values, where `k`
        depends on the corresponding batch inputs. Note that MVaR values in
        general are not in-sample points. If `self.filter_dominated` is falsy,
        the alpha level points are returned without removing dominated ones.
    """
    if Y.dim() == 2:
        Y = Y.unsqueeze(0)
    batch, m = Y.shape[0], Y.shape[-1]
    # Number of samples that must be >= a grid point (in every outcome) for
    # that point to sit exactly on the alpha level.
    alpha_count = ceil(self.alpha * self.n_w)
    # Note that points in MVaR are bounded from above by the
    # independent VaR of each objective. Hence, we only need to
    # consider the unique outcomes that are less than or equal to
    # the VaR of the independent objectives.
    var_alpha_idx = alpha_count - 1
    n_points = Y.shape[-2] - var_alpha_idx
    # Per-outcome `n_points` smallest values: `batch x n_points x m`.
    Y_sorted = Y.topk(n_points, dim=-2, largest=False).values
    # `y_grid` is the grid formed by all inputs in each batch, flattened to
    # `batch x n_points^m x m`.
    if m == 2:
        # This is significantly faster but only works with m=2.
        y_grid = torch.stack(
            [
                Y_sorted[..., 0].repeat_interleave(repeats=n_points, dim=-1),
                Y_sorted[..., 1].repeat(1, n_points),
            ],
            dim=-1,
        )
    else:
        # Build the cartesian grid batch by batch. `indexing="ij"` is spelled
        # out explicitly: it matches the historical default and avoids the
        # deprecation warning raised when the argument is omitted.
        y_grid = torch.stack(
            [
                torch.stack(
                    torch.meshgrid(
                        *(Y_sorted[b, :, i] for i in range(m)),
                        indexing="ij",
                    ),
                    dim=-1,
                ).view(-1, m)
                for b in range(batch)
            ],
            dim=0,
        )
    # Get the non-normalized CDF: for each grid point, count the samples that
    # are >= that point in all `m` outcomes. Result is `batch x n_points^m`.
    cdf = (Y.unsqueeze(-2) >= y_grid.unsqueeze(-3)).all(dim=-1).sum(dim=-2)
    # NOTE: Need to loop here since mvar may have different shapes.
    mvar = []
    for b in range(batch):
        # Get the alpha level points.
        alpha_level_points = y_grid[b][cdf[b] == alpha_count]
        # If there are no exact alpha level points, get the smallest
        # alpha' > alpha and find the corresponding alpha level indices.
        if alpha_level_points.numel() == 0:
            min_greater_than_alpha = cdf[b][cdf[b] > alpha_count].min()
            alpha_level_points = y_grid[b][cdf[b] == min_greater_than_alpha]
        # MVaR is the non-dominated subset of alpha level points.
        if self.filter_dominated:
            mask = is_non_dominated(alpha_level_points)
            mvar.append(alpha_level_points[mask])
        else:
            mvar.append(alpha_level_points)
    return mvar