# src/peft/peft_model.py
def get_model_status(model: torch.nn.Module) -> TunerModelStatus:
"""Get the status of tuners of the model.
This function returns a `TunerModelStatus` dataclass instance, which contains the following attributes:
- `base_model_type` (`str`):
The type of the base model, e.g. `T5Model`.
- `adapter_model_type` (`str`):
The type of the adapter model, e.g. `LoraModel`.
- `peft_types` (`dict[str, str]`):
The mapping of adapter name to adapter type, e.g. `{"default": "LORA"}`.
- `trainable_params` (`int`):
The number of trainable parameters in the model.
- `total_params` (`int`):
The total number of parameters in the model.
- `num_adapter_layers` (`int`):
The number of adapter layers in the model.
- `enabled` (`bool`, `Literal["irregular"]`):
Whether all adapter layers are enabled. If some are enabled and some are not, this will be `"irregular"`. This
means that your model is in an inconsistent state and might not work as expected.
- `active_adapters` (`list[str]`, `Literal["irregular"]`):
The names of the active adapters. If the active adapters are not consistent across all layers, this will be
`"irregular"`, which means that your model is in an inconsistent state and might not work as expected.
- `merged_adapters` (`list[str]`, `Literal["irregular"]`):
The names of the merged adapters. If the merged adapters are not consistent across all layers, this will be
`"irregular"`, which means that your model is in an inconsistent state and might not work as expected.
- `requires_grad` (`dict[str, bool | Literal["irregular"]]`):
Whether for the given adapter, all adapter layers have `requires_grad` set to `True` or `False`. If there is a
mix, this will be set to `"irregular"`, which means that your model is in an inconsistent state and might not
work as expected.
- `available_adapters` (`list[str]`):
The names of the available adapters, e.g. `["default"]`.
- `devices` (`dict[str, list[str]]`):
The devices where the parameters of the given adapter are stored, e.g. `["cuda"]`.
Args:
model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]):
The model to get the adapter layer status from.
Returns:
`peft.peft_model.TunerModelStatus`:
A dataclass containing the status of the model.
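
    Example (a minimal usage sketch; assumes `get_model_status` is importable from `peft` and that `LoraConfig` can
    infer `target_modules` for the chosen base model):

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import LoraConfig, get_peft_model, get_model_status

        >>> base = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> peft_model = get_peft_model(base, LoraConfig())
        >>> status = get_model_status(peft_model)
        >>> status.peft_types  # e.g. {"default": "LORA"}
        >>> status.enabled  # True when all adapter layers agree
        ```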
"""
    if isinstance(model, PeftModel):
        if not isinstance(model.base_model, BaseTuner):
            raise TypeError(
                "get_model_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not "
                "supported."
            )
        base_model_type = model.get_base_model().__class__.__name__
        trainable_params, total_params = model.get_nb_trainable_parameters()
        base_model = model.base_model
        peft_types = {key: str(config.peft_type).partition(".")[-1] for key, config in base_model.peft_config.items()}
        adapter_model_type = base_model.__class__.__name__
    elif isinstance(model, PreTrainedModel):
        base_model_type = model.__class__.__name__
        trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model)
        base_model = model
        peft_types = {}
        adapter_model_type = "None"
    else:
        base_model_type = "other"
        trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model)
        base_model = model
        peft_types = {}
        adapter_model_type = "None"
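
    # Everything below aggregates the per-layer statuses reported by get_layer_status into model-wide values.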
    layer_status = get_layer_status(model)
    num_adapter_layers = len(layer_status)

    enabled_set: set[bool] = {status.enabled for status in layer_status}  # must be {True}, {False}, or {True, False}
    enabled: bool | Literal["irregular"]
    if len(enabled_set) == 1:
        enabled = enabled_set.pop()
    else:
        enabled = "irregular"

    available_adapters: list[str] = sorted(set().union(*(status.available_adapters for status in layer_status)))

    # ideally, active adapters should be consistent across all layers of the model, but we cannot guarantee it
    all_active_adapters: set[tuple[str, ...]] = {tuple(status.active_adapters) for status in layer_status}
    active_adapters: list[str] | Literal["irregular"]
    if not all_active_adapters:
        active_adapters = []
    elif len(all_active_adapters) == 1:
        active_adapters = list(all_active_adapters.pop())
    else:
        active_adapters = "irregular"

    # Here we determine which adapters are merged. This is not trivial because multiple adapters can be merged or
    # unmerged at the same time. Some layers may only have adapter A, some only adapter B, so it's not as easy as
    # just checking which adapters are merged on each layer.
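    # For example, if adapter A is merged on one layer but available yet unmerged on another, the model is in an
    # inconsistent state and merged_adapters is reported as "irregular".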

    # First, determine all adapters that are merged on at least one module.
    merged_all: set[str] = set()
    for status in layer_status:
        merged_all.update(status.merged_adapters)

    # Next, check if, on any layer, one of these adapters is not merged.
    merged_adapters: list[str] | Literal["irregular"] = sorted(merged_all)
    for status in layer_status:
        unmerged = set(status.available_adapters) - set(status.merged_adapters)
        if unmerged & merged_all:
            # there is overlap between unmerged adapters and adapters that should be merged
            merged_adapters = "irregular"
            break

    # check the status of requires_grad
    # first, merge the values for all layers
    requires_grad_all: dict[str, list[bool | Literal["irregular"]]] = collections.defaultdict(list)
    for status in layer_status:
        for key, val in status.requires_grad.items():
            requires_grad_all[key].append(val)

    # then, check if the values are consistent
    def check_irregular(vals: list[bool | Literal["irregular"]]) -> bool | Literal["irregular"]:
        if all(val is True for val in vals):
            return True
        if all(val is False for val in vals):
            return False
        return "irregular"

    requires_grad = {key: check_irregular(vals) for key, vals in requires_grad_all.items()}
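    # requires_grad is now e.g. {"default": True} when all layers agree, or {"default": "irregular"} on a mix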

    devices_dd = collections.defaultdict(list)
    for status in layer_status:
        for key, val in status.devices.items():
            devices_dd[key].extend(val)
    devices = {key: sorted(set(val)) for key, val in devices_dd.items()}
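    # devices is e.g. {"default": ["cpu", "cuda:0"]} if an adapter's parameters are spread across devices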

    adapter_model_status = TunerModelStatus(
        base_model_type=base_model_type,
        adapter_model_type=adapter_model_type,
        peft_types=peft_types,
        trainable_params=trainable_params,
        total_params=total_params,
        num_adapter_layers=num_adapter_layers,
        enabled=enabled,
        active_adapters=active_adapters,
        merged_adapters=merged_adapters,
        requires_grad=requires_grad,
        available_adapters=available_adapters,
        devices=devices,
    )
    return adapter_model_status