in botorch/models/converter.py [0:0]
def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Split a batched multi-output model into a `ModelListGP`.

    One sub-model of the same class as `batch_model` is created per output
    (`batch_model._num_outputs` in total). Each sub-model is constructed from
    the i-th slice of the batched training data and then loaded with a state
    dict assembled from the batched model's state dict.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to a
            `ModelListGP`.

    Returns:
        A `ModelListGP` wrapping the per-output models, in output order.

    Raises:
        NotImplementedError: If `batch_model` is a `HeteroskedasticSingleTaskGP`
            or a `MixedSingleTaskGP`.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    # TODO: Add support for HeteroskedasticSingleTaskGP.
    if isinstance(batch_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently not supported."
        )
    if isinstance(batch_model, MixedSingleTaskGP):
        raise NotImplementedError(
            "Conversion of MixedSingleTaskGP is currently not supported."
        )

    # Transforms are optional attributes on the batched model.
    input_transform = getattr(batch_model, "input_transform", None)
    outcome_transform = getattr(batch_model, "outcome_transform", None)

    batch_sd = batch_model.state_dict()
    # Keys in the first group are sliced per output below; keys in the second
    # group are copied unchanged into every sub-model.
    adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
        batch_state_dict=batch_sd,
        input_transform=input_transform,
        outcome_transform=outcome_transform,
    )

    # The dimension that is indexed per output sits right after the input
    # batch dimensions.
    input_bdims = len(batch_model._input_batch_shape)
    model_cls = batch_model.__class__

    models = []
    for i in range(batch_model._num_outputs):
        # Assemble the state dict for the i-th sub-model. Non-adjusted entries
        # go in first so that adjusted entries win on any shared key.
        sd = {}
        for key in non_adjusted_batch_keys:
            sd[key] = batch_sd[key].clone()
        for key in adjusted_batch_keys:
            entry = batch_sd[key]
            if "active_dims" in key:
                # `active_dims` entries are copied whole rather than sliced.
                sd[key] = entry.clone()
            else:
                sd[key] = entry.select(input_bdims, i).clone()

        # Constructor arguments: the i-th slice of the training data, with a
        # trailing singleton dimension appended to the targets.
        inputs_i = batch_model.train_inputs[0].select(input_bdims, i)
        targets_i = batch_model.train_targets.select(input_bdims, i)
        kwargs = {
            "train_X": inputs_i.clone(),
            "train_Y": targets_i.clone().unsqueeze(-1),
        }
        if isinstance(batch_model, FixedNoiseGP):
            noise_i = batch_model.likelihood.noise_covar.noise.select(input_bdims, i)
            kwargs["train_Yvar"] = noise_i.clone().unsqueeze(-1)
        if isinstance(batch_model, SingleTaskMultiFidelityGP):
            kwargs.update(batch_model._init_args)

        # NOTE: The outcome transform is always passed through `kwargs` (and
        # set after the update above) to avoid the "multiple values for the
        # same kwarg" issue with SingleTaskMultiFidelityGP.
        octf = None
        if outcome_transform is not None:
            octf = outcome_transform.subset_output(idcs=[i])
            # Overwrite the outcome transform state dict entries with those of
            # the single-output transform.
            for name, value in octf.state_dict().items():
                sd["outcome_transform." + name] = value
        kwargs["outcome_transform"] = octf

        model = model_cls(input_transform=input_transform, **kwargs)
        model.load_state_dict(sd)
        models.append(model)

    return ModelListGP(*models)