def load_mcmc_samples()

in botorch/models/fully_bayesian.py [0:0]


    def load_mcmc_samples(self, mcmc_samples: Dict[str, Tensor]) -> None:
        r"""Load the MCMC hyperparameter samples into the model.

        This method will be called by `fit_fully_bayesian_model_nuts` when the model
        has been fitted in order to create a batched SingleTaskGP model.

        The existing mean module, covariance module and likelihood are replaced by
        fresh modules with batch shape `num_mcmc_samples`. Their hyperparameters are
        then set from the corresponding MCMC samples.

        Args:
            mcmc_samples: Dictionary of MCMC samples keyed by hyperparameter name.
                The keys read here are `"mean"`, `"lengthscale"`, `"outputscale"`
                and, only when `self.train_Yvar is None`, `"noise"`. The number of
                MCMC samples is taken from `len(mcmc_samples["mean"])`. Each entry
                must have as many elements as the batched hyperparameter it is
                viewed into.
        """
        tkwargs = {"device": self.train_X.device, "dtype": self.train_X.dtype}
        num_mcmc_samples = len(mcmc_samples["mean"])
        batch_shape = torch.Size([num_mcmc_samples])
        self.mean_module = ConstantMean(batch_shape=batch_shape).to(**tkwargs)
        self.covar_module = ScaleKernel(
            base_kernel=MaternKernel(
                ard_num_dims=self.train_X.shape[-1],
                batch_shape=batch_shape,
            ),
            batch_shape=batch_shape,
        ).to(**tkwargs)
        if self.train_Yvar is not None:
            # The batched fixed-noise likelihood expects noise shaped
            # `batch_shape x n`, i.e. `num_mcmc_samples x n`. An `n x 1` observed
            # noise tensor does not match that layout, so drop a trailing
            # singleton output dimension and broadcast it across the MCMC batch.
            noise = self.train_Yvar.squeeze(-1).expand(
                num_mcmc_samples, len(self.train_Yvar)
            )
            self.likelihood = FixedNoiseGaussianLikelihood(
                noise=noise, batch_shape=batch_shape
            ).to(**tkwargs)
        else:
            self.likelihood = GaussianLikelihood(
                batch_shape=batch_shape,
                noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL),
            ).to(**tkwargs)
            # Clamp the sampled noise so it respects the same lower bound that the
            # likelihood's `GreaterThan` constraint enforces.
            self.likelihood.noise_covar.noise = (
                mcmc_samples["noise"]
                .detach()
                .clone()
                .view(self.likelihood.noise_covar.noise.shape)
                .clamp_min(MIN_INFERRED_NOISE_LEVEL)
                .to(**tkwargs)
            )

        # Each sample tensor is detached and cloned, so the model does not share
        # storage or autograd history with the MCMC output. It is then viewed into
        # the shape of the batched hyperparameter it replaces.
        self.covar_module.base_kernel.lengthscale = (
            mcmc_samples["lengthscale"]
            .detach()
            .clone()
            .view(self.covar_module.base_kernel.lengthscale.shape)
            .to(**tkwargs)
        )
        self.covar_module.outputscale = (
            mcmc_samples["outputscale"]
            .detach()
            .clone()
            .view(self.covar_module.outputscale.shape)
            .to(**tkwargs)
        )
        # Write through `.data` so the assignment is not tracked by autograd.
        self.mean_module.constant.data = (
            mcmc_samples["mean"]
            .detach()
            .clone()
            .view(self.mean_module.constant.shape)
            .to(**tkwargs)
        )