def get_mvar_set_gpu()

in botorch/acquisition/multi_objective/multi_output_risk_measures.py


    def get_mvar_set_gpu(self, Y: Tensor) -> List[Tensor]:
        r"""Find MVaR set based on the definition in [Prekopa2012MVaR]_.

        NOTE: This is much faster on GPU than the alternative but it scales very poorly
        on CPU as `n_w` increases. This should be preferred if a GPU is available or
        when `n_w <= 64`. In addition, this supports `m >= 2` outcomes (vs `m = 2` for
        the CPU version) and it should be used if `m > 2`.

        This first calculates the CDF for each point on the extended domain of the
        random variable (the grid defined by the given samples), then takes the
        values with CDF equal to `alpha` (rounded up when an exact match is not
        attainable). The non-dominated subset of these forms the MVaR set.

        Args:
            Y: A `batch x n_w x m`-dim tensor of observations.

        Returns:
            A `batch`-length list of `k x m`-dim tensors of MVaR values, where `k`
            depends on the corresponding batch inputs. Note that MVaR values in general
            are not in-sample points.
        """
        if Y.dim() == 2:
            Y = Y.unsqueeze(0)
        batch, m = Y.shape[0], Y.shape[-1]
        # Note that points in MVaR are bounded from above by the
        # independent VaR of each objective. Hence, we only need to
        # consider the unique outcomes that are less than or equal to
        # the VaR of the independent objectives
        var_alpha_idx = ceil(self.alpha * self.n_w) - 1
        n_points = Y.shape[-2] - var_alpha_idx
        Y_sorted = Y.topk(n_points, dim=-2, largest=False).values
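        # `Y_sorted` is `batch x n_points x m`; per objective, it keeps the
        # `n_points` smallest sample values, i.e., those not exceeding that
        # objective's independent VaR.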
        # `y_grid` is the grid formed by all inputs in each batch.
        if m == 2:
            # This is significantly faster but only works with m=2.
            y_grid = torch.stack(
                [
                    Y_sorted[..., 0].repeat_interleave(repeats=n_points, dim=-1),
                    Y_sorted[..., 1].repeat(1, n_points),
                ],
                dim=-1,
            )
        else:
            y_grid = torch.stack(
                [
                    torch.stack(
                        torch.meshgrid([Y_sorted[b, :, i] for i in range(m)]),
                        dim=-1,
                    ).view(-1, m)
                    for b in range(batch)
                ],
                dim=0,
            )
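        # `y_grid` is `batch x n_points^m x m`: the Cartesian product of the
        # retained per-objective values within each batch.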
        # Get the non-normalized CDF.
        cdf = (Y.unsqueeze(-2) >= y_grid.unsqueeze(-3)).all(dim=-1).sum(dim=-2)
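        # `cdf[b][j]` counts the samples in `Y[b]` that are component-wise `>=`
        # grid point `j`, so each entry lies in `{0, ..., n_w}`.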
        # Get the alpha level points
        alpha_count = ceil(self.alpha * self.n_w)
        # NOTE: Need to loop here since mvar may have different shapes.
        mvar = []
        for b in range(batch):
            alpha_level_points = y_grid[b][cdf[b] == alpha_count]
            # If there are no exact alpha level points, get the smallest alpha' > alpha
            # and find the corresponding alpha level indices.
            if alpha_level_points.numel() == 0:
                min_greater_than_alpha = cdf[b][cdf[b] > alpha_count].min()
                alpha_level_points = y_grid[b][cdf[b] == min_greater_than_alpha]
            # MVaR is the non-dominated subset of alpha level points.
            if self.filter_dominated:
                mask = is_non_dominated(alpha_level_points)
                mvar.append(alpha_level_points[mask])
            else:
                mvar.append(alpha_level_points)
        return mvar
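
A minimal usage sketch (not part of the source file), assuming the `MVaR`
constructor from this module takes `n_w` and `alpha` as its leading arguments
and defaults to `filter_dominated=True`; it calls `get_mvar_set_gpu` on a
single batch of random outcomes.

    import torch

    from botorch.acquisition.multi_objective.multi_output_risk_measures import MVaR

    # Hypothetical settings: 32 samples per batch element, alpha level 0.8.
    n_w, m = 32, 2
    mvar = MVaR(n_w=n_w, alpha=0.8)
    # `Y` holds `n_w` samples of `m` outcomes for a single batch element.
    Y = torch.randn(1, n_w, m)
    # Returns a `batch`-length (here 1) list of `k x m` tensors of MVaR points.
    mvar_set = mvar.get_mvar_set_gpu(Y)
    print(mvar_set[0].shape)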