in botorch/models/higher_order_gp.py [0:0]
def forward(self, X: Tensor) -> MultivariateNormal:
    """Compute the prior as a zero-mean, Kronecker-structured MVN.

    The first covariance factor is evaluated on the (possibly transformed)
    inputs ``X``; each remaining factor is evaluated on the corresponding
    stored latent parameter. The full covariance is the Kronecker product of
    all factors, paired with a zero mean of matching batch shape.

    Args:
        X: Input tensor. It is passed through ``self.transform_inputs`` only
            in training mode (eval-time transforms are presumably applied
            elsewhere, per the usual botorch convention — confirm in callers).

    Returns:
        A ``MultivariateNormal`` with zero mean and a
        ``KroneckerProductLazyTensor`` covariance.
    """
    if self.training:
        X = self.transform_inputs(X)
    # First factor depends on the inputs; the rest only on latent parameters.
    factors = [self.covar_modules[0](X)]
    for covar_module, latent in zip(self.covar_modules[1:], self.latent_parameters):
        if self.training:
            factors.append(covar_module(latent))
        else:
            # At eval time no gradients are needed through the latent factors.
            with torch.no_grad():
                factors.append(covar_module(latent))
    # Broadcast the latent factors up to the data factor's batch shape when
    # they disagree (only the second factor is checked, matching all factors).
    data_batch = factors[0].batch_shape
    if factors[1].batch_shape != data_batch:
        factors[1:] = [BatchRepeatLazyTensor(f, data_batch) for f in factors[1:]]
    kronecker_covariance = KroneckerProductLazyTensor(*factors)
    # TODO: expand options for the mean module via batch shaping?
    mean = torch.zeros(
        *data_batch,
        kronecker_covariance.shape[-1],
        dtype=kronecker_covariance.dtype,
        device=kronecker_covariance.device,
    )
    return MultivariateNormal(mean, kronecker_covariance)