def _verify_output_shape()

in botorch/utils/transforms.py [0:0]


def _verify_output_shape(acqf: Any, X: Tensor, output: Tensor) -> bool:
    r"""
    Performs the output shape checks for `t_batch_mode_transform`. Output shape checks
    help in catching the errors due to AcquisitionFunction arguments with erroneous
    return shapes before these errors propagate further down the line.

    The shape of `output` is accepted if any of the following holds (checked in
    this order):

    - it equals the t-batch shape of `X`, i.e. `X.shape[:-2]`;
    - it is a scalar shape (`torch.Size([])`) and the t-batch shape of `X` is `[1]`;
    - it equals `acqf.model.batch_shape`;
    - it equals `X.shape[:-3] + acqf.model.batch_shape`, i.e. the right-most
      t-batch dimension of `X` replaced by the model's batch shape.

    The model-based checks are only attempted when the first two fail. If `acqf`
    has no `model` attribute, or looking up `acqf.model.batch_shape` raises
    `AttributeError` / `NotImplementedError`, the check cannot be completed: a
    `RuntimeWarning` is emitted and the output is accepted.

    Args:
        acqf: The AcquisitionFunction object being evaluated.
        X: The `... x q x d`-dim input tensor with an explicit t-batch.
        output: The return value of `acqf.method(X, ...)`.

    Returns:
        True if `output` has the correct shape (or the check could not be
        performed), False otherwise.
    """
    t_batch_shape = X.shape[:-2]
    if output.shape == t_batch_shape:
        return True
    if output.shape == torch.Size() and t_batch_shape == torch.Size([1]):
        # A t-batch of size one may have been squeezed out of the output.
        return True

    # Only the model lookup can fail here; keep the `try` scoped to it.
    try:
        model_batch_shape = acqf.model.batch_shape
    except (AttributeError, NotImplementedError):
        # acqf does not have model or acqf.model does not define `batch_shape`
        warnings.warn(
            "Output shape checks failed! Expected output shape to match t-batch shape "
            f"of X, but got output with shape {output.shape} for X with shape "
            f"{X.shape}. Make sure that this is the intended behavior!",
            RuntimeWarning,
        )
        return True

    return (
        output.shape == model_batch_shape
        # for a batched model, we may unsqueeze a batch dimension in X
        # corresponding to the model's batch dim. In that case the
        # acquisition function output should replace the right-most
        # batch dim of X with the model's batch shape.
        or output.shape == X.shape[:-3] + model_batch_shape
    )