def _set_cell_bounds()

in botorch/acquisition/multi_objective/monte_carlo.py [0:0]


    def _set_cell_bounds(self, num_new_points: int) -> None:
        r"""Rebuild the box decompositions and store their hypercell bounds.

        Posterior samples at `X_baseline` are passed through the objective (and
        checked against the constraints, if any). The resulting values are
        partitioned -- batched when the reference point has at most two
        outcomes, otherwise one decomposition per sample on the CPU -- and the
        hypercell bounds are registered as the `cell_lower_bounds` and
        `cell_upper_bounds` buffers. If `X_baseline` is empty, an empty
        objective tensor is partitioned instead.

        Args:
            num_new_points: The number of new points (beyond the points
                in X_baseline) that were used in the previous box decomposition.
                In the first box decomposition, this should be the number of
                points in X_baseline. Multiplied by the one-to-many factor of
                the posterior, it is passed as `q_in` when resetting the sampler.
        """
        feasible = None
        if self.X_baseline.shape[0] == 0:
            # No baseline points: partition an empty set of objective values.
            outcomes = torch.empty(
                *self.sampler._sample_shape,
                0,
                self.ref_point.shape[-1],
                dtype=self.ref_point.dtype,
                device=self.ref_point.device,
            )
        else:
            with torch.no_grad():
                baseline_posterior = self.model.posterior(self.X_baseline)
            # Reset the sampler, accounting for a possible one-to-many transform.
            self.q_in = -1
            num_baseline = self.X_baseline.shape[-2]
            points_per_input = baseline_posterior.event_shape[-2] // num_baseline
            self._set_sampler(
                q_in=num_new_points * points_per_input, posterior=baseline_posterior
            )
            # Give base_sampler a detached copy of the sampler's base samples.
            frozen_base_samples = self.sampler.base_samples.detach().clone()
            self.base_sampler.register_buffer("base_samples", frozen_base_samples)

            posterior_samples = self.base_sampler(baseline_posterior)
            if self._cache_root:
                # cache the root decomposition of the baseline posterior
                self._cache_root_decomposition(posterior=baseline_posterior)
            outcomes = self.objective(posterior_samples, X=self.X_baseline)
            if self.constraints is not None:
                # a point is feasible only if every constraint value is <= 0
                satisfied = [con(posterior_samples) <= 0 for con in self.constraints]
                feasible = torch.stack(satisfied, dim=0).all(dim=0)

        self._batch_sample_shape = outcomes.shape[:-2]
        # Collapse all leading (batch/sample) dimensions into a single one.
        # numel() is used rather than view(-1) to handle the case of no
        # baseline points.
        flat_batch_size = self._batch_sample_shape.numel()
        outcomes = outcomes.view(flat_batch_size, *outcomes.shape[-2:])
        if feasible is not None and self.constraints is not None:
            feasible = feasible.view(flat_batch_size, *feasible.shape[-1:])

        needs_initial_hvs = self.partitioning is None and not self.incremental_nehvi
        if needs_initial_hvs:
            self._compute_initial_hvs(obj=outcomes, feas=feasible)

        if self.ref_point.shape[-1] <= 2:
            # use batched partitioning
            batch_ref_point = self.ref_point.unsqueeze(0).expand(
                outcomes.shape[0], self.ref_point.shape[-1]
            )
            padded_outcomes = _pad_batch_pareto_frontier(
                Y=outcomes,
                ref_point=batch_ref_point,
                feasibility_mask=feasible,
            )
            self.partitioning = self.p_class(
                ref_point=self.ref_point, Y=padded_outcomes, **self.p_kwargs
            )
        else:
            # the partitioning algorithms run faster on the CPU
            # due to advanced indexing
            ref_point_cpu = self.ref_point.cpu()
            outcomes_cpu = outcomes.cpu()
            if feasible is not None and self.constraints is not None:
                feasible_cpu = feasible.cpu()
                # keep only the feasible points of each sample
                outcomes_cpu = [
                    outcomes_cpu[idx][feasible_cpu[idx]]
                    for idx in range(outcomes.shape[0])
                ]
            # one box decomposition per (flattened) sample
            self.partitioning = BoxDecompositionList(
                *(
                    self.p_class(ref_point=ref_point_cpu, Y=sample_y, **self.p_kwargs)
                    for sample_y in outcomes_cpu
                )
            )

        bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point)
        # restore the original batch/sample dimensions after the leading
        # (lower, upper) dimension
        bounds = bounds.view(2, *self._batch_sample_shape, *bounds.shape[-2:])
        lower_bounds, upper_bounds = bounds[0], bounds[1]
        self.register_buffer("cell_lower_bounds", lower_bounds)
        self.register_buffer("cell_upper_bounds", upper_bounds)