def model_list_to_batched()

in botorch/models/converter.py
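
The function relies on module-level imports in converter.py; a plausible set (an assumption, not copied verbatim from the file) is shown below. The private helpers `_check_compatibility`, `_get_module`, and `_get_adjusted_batch_keys` used in the body are defined elsewhere in the same module.

from copy import deepcopy

import torch

from botorch.exceptions import UnsupportedError
from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP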


def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorchModel:
    """Convert a ModelListGP to a BatchedMultiOutputGPyTorchModel.

    Args:
        model_list: The `ModelListGP` to be converted to the appropriate
            `BatchedMultiOutputGPyTorchModel`. All sub-models must be of the
            same type and have the same training data shapes (batch shape and
            number of training inputs).

    Returns:
        The model converted into a `BatchedMultiOutputGPyTorchModel`.

    Example:
        >>> list_gp = ModelListGP(gp1, gp2)
        >>> batch_gp = model_list_to_batched(list_gp)
    """
    models = model_list.models
    _check_compatibility(models)

    # if the list has only one model, we can just return a copy of that
    if len(models) == 1:
        return deepcopy(models[0])

    # construct inputs
    train_X = deepcopy(models[0].train_inputs[0])
    train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1)
    kwargs = {"train_X": train_X, "train_Y": train_Y}
    if isinstance(models[0], FixedNoiseGP):
        kwargs["train_Yvar"] = torch.stack(
            [m.likelihood.noise_covar.noise.clone() for m in models], dim=-1
        )
    if isinstance(models[0], SingleTaskMultiFidelityGP):
        init_args = models[0]._init_args
        if not all(
            v == m._init_args[k] for m in models[1:] for k, v in init_args.items()
        ):
            raise UnsupportedError("All models must have the same fidelity parameters.")
        kwargs.update(init_args)

    # construct the batched GP model
    input_transform = getattr(models[0], "input_transform", None)
    batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
    # split the state dict keys into those whose values receive an output batch
    # dimension (stacked across sub-models below) and those that are shared
    # as-is, e.g. input transform parameters
    adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
        batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
    )
    # index at which the new output batch dimension is inserted when stacking
    input_batch_dims = len(models[0]._input_batch_shape)

    # ensure scalars agree (TODO: Allow different priors for different outputs)
    for n in non_adjusted_batch_keys:
        v0 = _get_module(models[0], n)
        if not all(torch.equal(_get_module(m, n), v0) for m in models[1:]):
            raise UnsupportedError("All scalars must have the same value.")

    # ensure dimensions of all tensors agree
    for n in adjusted_batch_keys:
        shape0 = _get_module(models[0], n).shape
        if not all(_get_module(m, n).shape == shape0 for m in models[1:]):
            raise UnsupportedError("All tensors must have the same shape.")

    # now construct the batched state dict
    non_adjusted_batch_state_dict = {
        s: p.clone()
        for s, p in models[0].state_dict().items()
        if s in non_adjusted_batch_keys
    }
    # stack each per-model tensor along a new output batch dimension; the
    # `active_dims` buffers hold input-dimension indices that are common to all
    # sub-models, so they are copied from the first model rather than stacked
    adjusted_batch_state_dict = {
        t: (
            torch.stack(
                [m.state_dict()[t].clone() for m in models], dim=input_batch_dims
            )
            if "active_dims" not in t
            else models[0].state_dict()[t].clone()
        )
        for t in adjusted_batch_keys
    }
    batch_state_dict = {**non_adjusted_batch_state_dict, **adjusted_batch_state_dict}

    # load the state dict into the new model
    batch_gp.load_state_dict(batch_state_dict)

    return batch_gp
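
A minimal usage sketch, expanding the docstring example (this assumes the sub-models are `SingleTaskGP`s sharing the same `train_X`; the training data below is synthetic and purely illustrative):

import torch

from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.converter import model_list_to_batched

# two single-output GPs trained on identical inputs, as required for conversion;
# no outcome transforms are used, so the models are directly convertible
train_X = torch.rand(10, 2, dtype=torch.double)
gp1 = SingleTaskGP(
    train_X, torch.sin(train_X).sum(-1, keepdim=True), outcome_transform=None
)
gp2 = SingleTaskGP(
    train_X, torch.cos(train_X).sum(-1, keepdim=True), outcome_transform=None
)

# convert the model list into a single batched two-output model
list_gp = ModelListGP(gp1, gp2)
batch_gp = model_list_to_batched(list_gp)

posterior = batch_gp.posterior(torch.rand(5, 2, dtype=torch.double))
print(posterior.mean.shape)  # torch.Size([5, 2]), one column per sub-model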