in botorch/utils/transforms.py [0:0]
def _verify_output_shape(acqf: Any, X: Tensor, output: Tensor) -> bool:
r"""
Performs the output shape checks for `t_batch_mode_transform`. Output shape checks
help in catching the errors due to AcquisitionFunction arguments with erroneous
return shapes before these errors propagate further down the line.
This method checks that the `output` shape matches either the t-batch shape of X
or the `batch_shape` of `acqf.model`.
Args:
acqf: The AcquisitionFunction object being evaluated.
X: The `... x q x d`-dim input tensor with an explicit t-batch.
output: The return value of `acqf.method(X, ...)`.
Returns:
True if `output` has the correct shape, False otherwise.
"""
try:
return (
output.shape == X.shape[:-2]
or (output.shape == torch.Size() and X.shape[:-2] == torch.Size([1]))
or output.shape == acqf.model.batch_shape
# for a batched model, we may unsqueeze a batch dimension in X
# corresponding to the model's batch dim. In that case the
# acquisition function output should replace the right-most
# batch dim of X with the model's batch shape.
or output.shape == X.shape[:-3] + acqf.model.batch_shape
)
except (AttributeError, NotImplementedError):
# acqf does not have model or acqf.model does not define `batch_shape`
warnings.warn(
"Output shape checks failed! Expected output shape to match t-batch shape"
f"of X, but got output with shape {output.shape} for X with shape"
f"{X.shape}. Make sure that this is the intended behavior!",
RuntimeWarning,
)
return True