def batched_to_model_list()

in botorch/models/converter.py [0:0]


def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.

    Each output dimension of the batched model becomes one single-output
    sub-model; parameters are sliced out of the batched state dict along the
    output batch dimension and loaded into the corresponding sub-model.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to a
            `ModelListGP`.

    Returns:
        The model converted into a `ModelListGP`.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    # TODO: Add support for HeteroskedasticSingleTaskGP.
    if isinstance(batch_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently not supported."
        )
    if isinstance(batch_model, MixedSingleTaskGP):
        raise NotImplementedError(
            "Conversion of MixedSingleTaskGP is currently not supported."
        )
    input_transform = getattr(batch_model, "input_transform", None)
    outcome_transform = getattr(batch_model, "outcome_transform", None)
    batch_sd = batch_model.state_dict()

    # Split the state-dict keys into those that carry an output batch
    # dimension (to be sliced per output) and those that do not.
    adjusted_keys, non_adjusted_keys = _get_adjusted_batch_keys(
        batch_state_dict=batch_sd,
        input_transform=input_transform,
        outcome_transform=outcome_transform,
    )
    batch_dim = len(batch_model._input_batch_shape)

    models = []
    for out_idx in range(batch_model._num_outputs):
        # Clone per sub-model so the resulting models do not share tensors.
        model_sd = {k: batch_sd[k].clone() for k in non_adjusted_keys}
        for k in adjusted_keys:
            # Keys containing "active_dims" are shared across outputs and
            # copied wholesale; all others are sliced at this output index.
            if "active_dims" in k:
                model_sd[k] = batch_sd[k].clone()
            else:
                model_sd[k] = batch_sd[k].select(batch_dim, out_idx).clone()
        init_kwargs = {
            "train_X": batch_model.train_inputs[0].select(batch_dim, out_idx).clone(),
            "train_Y": (
                batch_model.train_targets.select(batch_dim, out_idx)
                .clone()
                .unsqueeze(-1)
            ),
        }
        if isinstance(batch_model, FixedNoiseGP):
            noise = batch_model.likelihood.noise_covar.noise
            init_kwargs["train_Yvar"] = (
                noise.select(batch_dim, out_idx).clone().unsqueeze(-1)
            )
        if isinstance(batch_model, SingleTaskMultiFidelityGP):
            init_kwargs.update(batch_model._init_args)
        # NOTE: Passing the outcome transform through kwargs avoids the
        # "multiple values for same kwarg" issue with SingleTaskMultiFidelityGP.
        if outcome_transform is None:
            init_kwargs["outcome_transform"] = None
        else:
            octf = outcome_transform.subset_output(idcs=[out_idx])
            init_kwargs["outcome_transform"] = octf
            # Overwrite the outcome-transform entries with the subset ones.
            for k, v in octf.state_dict().items():
                model_sd["outcome_transform." + k] = v
        model = batch_model.__class__(input_transform=input_transform, **init_kwargs)
        model.load_state_dict(model_sd)
        models.append(model)

    return ModelListGP(*models)