in botorch/models/converter.py [0:0]
def _check_compatibility(models: ModelListGP) -> None:
    """Check if a ModelListGP can be converted.

    Runs a sequence of validations over the sub-models and raises on the
    first one that fails. Returns normally when every check passes.

    Args:
        models: The sub-models to validate. The container is indexed, sliced
            and iterated, and must contain at least one model.

    Raises:
        UnsupportedError: If sub-modules differ in type across models, a model
            is not a `BatchedMultiOutputGPyTorchModel`, a model has an outcome
            transform, a model is not single-output, training inputs differ,
            an input transform is batched, or the input transforms are not
            the same for all models (including the case where only some of
            the models have one).
        NotImplementedError: If a model is a `HeteroskedasticSingleTaskGP` or
            is flagged as using a custom likelihood.
    """
    # Check that all submodules are of the same type.
    for modn, mod in models[0].named_modules():
        mcls = mod.__class__
        if not all(isinstance(_get_module(m, modn), mcls) for m in models[1:]):
            raise UnsupportedError(
                "Sub-modules must be of the same type across models."
            )

    # Check that each model is a BatchedMultiOutputGPyTorchModel.
    if not all(isinstance(m, BatchedMultiOutputGPyTorchModel) for m in models):
        raise UnsupportedError(
            "All models must be of type BatchedMultiOutputGPyTorchModel."
        )

    # TODO: Add support for HeteroskedasticSingleTaskGP.
    if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
        )

    # TODO: Add support for custom likelihoods.
    if any(getattr(m, "_is_custom_likelihood", False) for m in models):
        raise NotImplementedError(
            "Conversion of models with custom likelihoods is currently unsupported."
        )

    # TODO: Add support for outcome transforms.
    if any(getattr(m, "outcome_transform", None) is not None for m in models):
        raise UnsupportedError(
            "Conversion of models with outcome transforms is currently unsupported."
        )

    # Check that each model is single-output.
    if not all(m._num_outputs == 1 for m in models):
        raise UnsupportedError("All models must be single-output.")

    # Check that training inputs are the same as those of the first model.
    if not all(
        torch.equal(ti, tj)
        for m in models[1:]
        for ti, tj in zip(models[0].train_inputs, m.train_inputs)
    ):
        raise UnsupportedError("training inputs must agree for all sub-models.")

    # Collect the input transforms once; a model without the attribute is
    # treated the same as a model whose `input_transform` is None.
    input_transforms = [getattr(m, "input_transform", None) for m in models]

    # Check that there are no batched input transforms. A transform without
    # a `batch_shape` attribute is treated as non-batched.
    default_size = torch.Size([])
    for tfm in input_transforms:
        if tfm is not None and len(getattr(tfm, "batch_shape", default_size)) != 0:
            raise UnsupportedError("Batched input_transforms are not supported.")

    # Check that all models have the same input transforms. If only some of
    # the models have a transform, that is a mismatch and must surface as an
    # UnsupportedError rather than an AttributeError from a missing or None
    # `input_transform`.
    if any(tfm is not None for tfm in input_transforms):
        reference_tfm = input_transforms[0]
        if reference_tfm is None or not all(
            tfm is not None and tfm.equals(reference_tfm)
            for tfm in input_transforms[1:]
        ):
            raise UnsupportedError("All models must have the same input_transforms.")