def get_model_status()

in src/peft/peft_model.py [0:0]


def get_model_status(model: torch.nn.Module) -> TunerModelStatus:
    """Condense the per-layer tuner status of `model` into a single model-level summary.

    The per-layer information comes from `get_layer_status(model)`; it is aggregated into a `TunerModelStatus`
    dataclass instance with the following attributes:

    - `base_model_type` (`str`):
       The class name of the base model, e.g. `T5Model`. This is `"other"` when `model` is neither a `PeftModel` nor
       a `PreTrainedModel`.
    - `adapter_model_type` (`str`):
       The class name of the adapter model, e.g. `LoraModel`. This is the string `"None"` when `model` is not a
       `PeftModel`.
    - `peft_types` (`dict[str, str]`):
       Maps each adapter name to its adapter type, e.g. `{"default": "LORA"}`. Empty when `model` is not a
       `PeftModel`.
    - `trainable_params` (`int`):
       How many parameters of the model are trainable.
    - `total_params` (`int`):
       How many parameters the model has in total.
    - `num_adapter_layers` (`int`):
       How many adapter layers were reported for the model.
    - `enabled` (`bool`, `Literal["irregular"]`):
       The enabled flag shared by all adapter layers. If the layers do not all agree, this is `"irregular"`, meaning
       the model is in an inconsistent state and might not work as expected.
    - `active_adapters` (`list[str]`, `Literal["irregular"]`):
       The names of the active adapters. If the layers disagree about which adapters are active, this is
       `"irregular"`, meaning the model is in an inconsistent state and might not work as expected.
    - `merged_adapters` (`list[str]`, `Literal["irregular"]`):
       The sorted names of the merged adapters. If an adapter is merged on some layer but available and unmerged on
       another, this is `"irregular"`, meaning the model is in an inconsistent state and might not work as expected.
    - `requires_grad` (`dict[str, bool | Literal["irregular"]]`):
       For each adapter, `True` or `False` if every adapter layer reports exactly that value for `requires_grad`,
       otherwise `"irregular"`, meaning the model is in an inconsistent state and might not work as expected.
    - `available_adapters` (`list[str]`):
       The sorted names of all adapters available on at least one layer, e.g. `["default"]`.
    - `devices` (`dict[str, list[str]]`):
       For each adapter, the sorted, de-duplicated devices on which its parameters are stored, e.g.
       `{"default": ["cuda"]}`.

    Args:
        model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]):
            The model whose adapter status should be summarized.

    Returns:
        `peft.peft_model.TunerModelStatus`:
            A dataclass describing the status of the model.

    Raises:
        TypeError:
            If `model` is a `PeftModel` whose `base_model` is not a `BaseTuner`.

    """
    # --- model-level metadata -----------------------------------------------------------------------------------
    if isinstance(model, PeftModel):
        tuner = model.base_model
        if not isinstance(tuner, BaseTuner):
            raise TypeError(
                "get_model_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not "
                "supported."
            )
        base_model_type = model.get_base_model().__class__.__name__
        trainable_params, total_params = model.get_nb_trainable_parameters()
        # str(peft_type) looks like "<EnumName>.<MEMBER>"; keep only what follows the first dot
        peft_types = {name: str(cfg.peft_type).partition(".")[-1] for name, cfg in tuner.peft_config.items()}
        adapter_model_type = tuner.__class__.__name__
    else:
        # Plain transformers models and arbitrary modules only differ in how the base model is labelled.
        base_model_type = model.__class__.__name__ if isinstance(model, PreTrainedModel) else "other"
        trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model)
        peft_types = {}
        adapter_model_type = "None"

    # --- per-layer information, aggregated below ----------------------------------------------------------------
    layer_status = get_layer_status(model)
    num_adapter_layers = len(layer_status)

    # enabled: a single shared value is reported as-is, anything else counts as irregular
    enabled_states: set[bool] = {layer.enabled for layer in layer_status}
    enabled: bool | Literal["irregular"] = next(iter(enabled_states)) if len(enabled_states) == 1 else "irregular"

    available_adapters: list[str] = sorted({name for layer in layer_status for name in layer.available_adapters})

    # Active adapters should ideally be identical on every layer of the model, but this cannot be guaranteed.
    active_combinations: set[tuple[str, ...]] = {tuple(layer.active_adapters) for layer in layer_status}
    active_adapters: list[str] | Literal["irregular"]
    if len(active_combinations) > 1:
        active_adapters = "irregular"
    elif active_combinations:
        active_adapters = list(next(iter(active_combinations)))
    else:
        active_adapters = []

    # Determining the merged adapters is not trivial: several adapters can be merged (or not) at the same time, and
    # some layers may only carry adapter A while others only carry adapter B. Looking at each layer in isolation is
    # therefore not enough.

    # Step 1: collect every adapter that is merged on at least one module.
    ever_merged: set[str] = set()
    for layer in layer_status:
        ever_merged |= set(layer.merged_adapters)

    # Step 2: if any layer has one of those adapters available but not merged, the merge state is inconsistent.
    merged_adapters: list[str] | Literal["irregular"] = sorted(ever_merged)
    if any((set(layer.available_adapters) - set(layer.merged_adapters)) & ever_merged for layer in layer_status):
        merged_adapters = "irregular"

    # requires_grad: gather the values reported by each layer per adapter ...
    grad_flags: dict[str, list[bool | Literal["irregular"]]] = {}
    for layer in layer_status:
        for adapter_name, flag in layer.requires_grad.items():
            grad_flags.setdefault(adapter_name, []).append(flag)

    # ... and collapse them to one value; identity checks ensure that only real `True`/`False` values count
    def _collapse(flags: list[bool | Literal["irregular"]]) -> bool | Literal["irregular"]:
        for candidate in (True, False):
            if all(flag is candidate for flag in flags):
                return candidate
        return "irregular"

    requires_grad = {adapter_name: _collapse(flags) for adapter_name, flags in grad_flags.items()}

    # devices: union of the devices reported by all layers, per adapter
    device_sets: dict[str, set[str]] = {}
    for layer in layer_status:
        for adapter_name, locations in layer.devices.items():
            device_sets.setdefault(adapter_name, set()).update(locations)
    devices = {adapter_name: sorted(locations) for adapter_name, locations in device_sets.items()}

    return TunerModelStatus(
        base_model_type=base_model_type,
        adapter_model_type=adapter_model_type,
        peft_types=peft_types,
        trainable_params=trainable_params,
        total_params=total_params,
        num_adapter_layers=num_adapter_layers,
        enabled=enabled,
        active_adapters=active_adapters,
        merged_adapters=merged_adapters,
        requires_grad=requires_grad,
        available_adapters=available_adapters,
        devices=devices,
    )