def forward()

in botorch/models/higher_order_gp.py [0:0]


    def forward(self, X: Tensor) -> MultivariateNormal:
        r"""Compute the zero-mean Kronecker-structured distribution for ``X``.

        The covariance is the Kronecker product of a data covariance, evaluated
        on ``X`` by ``self.covar_modules[0]``, and one covariance per entry of
        ``self.latent_parameters``, each evaluated by the matching module in
        ``self.covar_modules[1:]``.

        Args:
            X: Input tensor passed to ``self.covar_modules[0]``. It is passed
                through ``self.transform_inputs`` only in training mode.

        Returns:
            A ``MultivariateNormal`` with an all-zeros mean. Its batch shape is
            the batch shape of the data covariance, and its event size is the
            size of the Kronecker product covariance.
        """
        if self.training:
            # NOTE(review): in eval mode X is used as given; presumably input
            # transforms are applied elsewhere in that case — confirm.
            X = self.transform_inputs(X)

        data_covar = self.covar_modules[0](X)
        covariance_list = [data_covar]
        for cm, param in zip(self.covar_modules[1:], self.latent_parameters):
            if self.training:
                covariance_list.append(cm(param))
            else:
                # Outside training, latent covariances are built without
                # tracking gradients.
                with torch.no_grad():
                    covariance_list.append(cm(param))

        # Align every latent covariance with the data covariance's batch shape.
        # BatchRepeatLazyTensor takes per-dimension *repeat counts* (applied to
        # the base batch shape left-padded with ones), not a target shape. So
        # only size-1 (or padded) dims are repeated; passing the full target
        # shape would multiply dims the latent covariance already has (e.g.
        # (m,) -> (b, m * m) instead of (b, m)). Entries that already match
        # are left untouched, and an empty latent list needs no alignment.
        target_batch_shape = data_covar.batch_shape
        for i in range(1, len(covariance_list)):
            cm = covariance_list[i]
            batch_shape = cm.batch_shape
            if batch_shape == target_batch_shape:
                continue
            n_pad = len(target_batch_shape) - len(batch_shape)
            padded = torch.Size([1] * n_pad) + batch_shape
            repeats = torch.Size(
                [t if b == 1 else 1 for t, b in zip(target_batch_shape, padded)]
            )
            covariance_list[i] = BatchRepeatLazyTensor(cm, repeats)
        kronecker_covariance = KroneckerProductLazyTensor(*covariance_list)

        # TODO: expand options for the mean module via batch shaping?
        mean = torch.zeros(
            *target_batch_shape,
            kronecker_covariance.shape[-1],
            device=kronecker_covariance.device,
            dtype=kronecker_covariance.dtype,
        )
        return MultivariateNormal(mean, kronecker_covariance)