in ax/models/torch/botorch_modular/utils.py
from typing import List, Type

import torch
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model import Model
from botorch.models.multitask import MultiTaskGP
from torch import Tensor


def use_model_list(Xs: List[Tensor], botorch_model_class: Type[Model]) -> bool:
    if issubclass(botorch_model_class, MultiTaskGP):
        # We currently always wrap multi-task models into `ModelListGP`.
        return True
    if len(Xs) == 1:
        # Just one outcome, so a single model suffices.
        return False
    if issubclass(botorch_model_class, BatchedMultiOutputGPyTorchModel) and all(
        torch.equal(Xs[0], X) for X in Xs[1:]
    ):
        # All outcomes share the same training inputs and the model class
        # supports batched multi-output, so a single model suffices.
        return False
    # Multiple Xs that are not all equal: use `ListSurrogate` and `ModelListGP`.
    return True
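
A minimal usage sketch (not part of the file above); it assumes BoTorch's `SingleTaskGP`, which subclasses `BatchedMultiOutputGPyTorchModel`, and uses hypothetical training-input tensors to show when the model-list path is chosen:

import torch
from botorch.models import SingleTaskGP

# Two outcomes observed at different input sets -> wrap in `ModelListGP`.
Xs = [torch.rand(10, 3), torch.rand(8, 3)]
assert use_model_list(Xs, SingleTaskGP) is True

# Two outcomes observed at identical inputs -> a single batched multi-output model.
X = torch.rand(10, 3)
assert use_model_list([X, X], SingleTaskGP) is False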