in botorch/models/converter.py [0:0]
def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorchModel:
    """Convert a ModelListGP to a BatchedMultiOutputGPyTorchModel.

    Args:
        model_list: The `ModelListGP` to be converted to the appropriate
            `BatchedMultiOutputGPyTorchModel`. All sub-models must be of the same
            type and have the same shape (batch shape and number of training
            inputs).

    Returns:
        The model converted into a `BatchedMultiOutputGPyTorchModel`. If
        `model_list` contains a single sub-model, a deep copy of that sub-model
        is returned instead.

    Raises:
        UnsupportedError: If the sub-models are `SingleTaskMultiFidelityGP`s
            whose `_init_args` differ, if a non-adjusted state entry differs in
            value between sub-models, or if an adjusted state entry differs in
            shape between sub-models. `_check_compatibility` is called first
            and may reject the input on its own terms.

    Example:
        >>> list_gp = ModelListGP(gp1, gp2)
        >>> batch_gp = model_list_to_batched(list_gp)
    """
    models = model_list.models
    _check_compatibility(models)

    # if the list has only one model, we can just return a copy of that
    if len(models) == 1:
        return deepcopy(models[0])

    # construct inputs: the training inputs are taken from the first model only,
    # while every model's targets are stacked along a new trailing dimension
    train_X = deepcopy(models[0].train_inputs[0])
    train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1)
    kwargs = {"train_X": train_X, "train_Y": train_Y}
    if isinstance(models[0], FixedNoiseGP):
        # the observed noise levels are stacked the same way as the targets
        kwargs["train_Yvar"] = torch.stack(
            [m.likelihood.noise_covar.noise.clone() for m in models], dim=-1
        )
    if isinstance(models[0], SingleTaskMultiFidelityGP):
        init_args = models[0]._init_args
        # every remaining model must match the first one on each init arg
        if not all(
            v == m._init_args[k] for m in models[1:] for k, v in init_args.items()
        ):
            raise UnsupportedError("All models must have the same fidelity parameters.")
        kwargs.update(init_args)

    # construct the batched GP model (same class as the first sub-model)
    input_transform = getattr(models[0], "input_transform", None)
    batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
    adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
        batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
    )
    input_batch_dims = len(models[0]._input_batch_shape)

    # ensure scalars agree (TODO: Allow different priors for different outputs)
    for n in non_adjusted_batch_keys:
        v0 = _get_module(models[0], n)
        if not all(torch.equal(_get_module(m, n), v0) for m in models[1:]):
            raise UnsupportedError("All scalars must have the same value.")

    # ensure dimensions of all tensors agree
    for n in adjusted_batch_keys:
        shape0 = _get_module(models[0], n).shape
        if not all(_get_module(m, n).shape == shape0 for m in models[1:]):
            raise UnsupportedError("All tensors must have the same shape.")

    # Fetch each sub-model's state dict once, rather than once per key inside
    # the comprehensions below.
    state_dicts = [m.state_dict() for m in models]
    first_state_dict = state_dicts[0]

    # now construct the batched state dict
    # non-adjusted entries were verified equal above, so the first model's
    # values are used as they are
    non_adjusted_batch_state_dict = {
        s: p.clone()
        for s, p in first_state_dict.items()
        if s in non_adjusted_batch_keys
    }
    # adjusted entries are stacked right after the input batch dimensions;
    # keys containing "active_dims" are not stacked and come from the first model
    adjusted_batch_state_dict = {
        t: (
            torch.stack([sd[t].clone() for sd in state_dicts], dim=input_batch_dims)
            if "active_dims" not in t
            else first_state_dict[t].clone()
        )
        for t in adjusted_batch_keys
    }
    batch_state_dict = {**non_adjusted_batch_state_dict, **adjusted_batch_state_dict}

    # load the state dict into the new model
    batch_gp.load_state_dict(batch_state_dict)

    return batch_gp