def _maybe_include_all_linear_layers()

in src/peft/tuners/tuners_utils.py [0:0]


def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module) -> PeftConfig:
    """
    Helper function to update `target_modules` to all linear/Conv1D layers if provided as 'all-linear'. Adapted from
    the QLoRA repository: https://github.com/artidoro/qlora/blob/main/qlora.py

    Args:
        peft_config: The config to update. It is modified in place, and only when its `target_modules` is a string
            that, lower-cased, equals `INCLUDE_LINEAR_LAYERS_SHORTHAND`.
        model: The model whose modules are inspected to collect the names of the linear layers.

    Returns:
        The same `peft_config` object. If the shorthand was given, its `target_modules` is replaced by a set of
        module names; otherwise the config is returned unchanged.
    """
    if not hasattr(peft_config, "target_modules"):
        return peft_config

    # if `target_modules` is a string, convert to lower case and check if it matches "all-linear"
    if not (
        isinstance(peft_config.target_modules, str)
        and peft_config.target_modules.lower() == INCLUDE_LINEAR_LAYERS_SHORTHAND
    ):
        return peft_config

    # Walk the model a single time and reuse the result for every lookup below.
    named_modules = list(model.named_modules())

    def find_module_name(target):
        # Name under which `target` appears in `named_modules`, or None if it is not among them.
        return next((name for name, module in named_modules if module is target), None)

    linear_classes = (torch.nn.Linear, Conv1D)
    linear_names = ("Linear",)
    linear_module_names = set()
    for name, module in named_modules:
        # match with all linear classes.
        if isinstance(module, linear_classes):
            linear_module_names.add(name)
        elif isinstance(module, BaseTunerLayer) and any(n in type(module).__name__ for n in linear_names):
            # If the model already has adapter layers applied, then the "linear" layer is actually an adapter layer,
            # e.g. lora.Linear, and not nn.Linear. To target this layer, we don't want to check the layer type, as there
            # are many possible layer types (one for each PEFT method) and the list would quickly get out of date. Thus
            # we rely on the name of the layer class, which by convention is something like "Linear", "Linear4bit",
            # "HqqLoraLinear", ... in PEFT. It's not pretty but should generally work.
            # See 2390
            linear_module_names.add(name)

    # Try to remove linear layers that should not be targeted as best as possible. We have to rely on convention as
    # there are no hard rules to detect these modules.
    module_names_to_exclude = set()
    if isinstance(model, PreTrainedModel):
        output_emb = model.get_output_embeddings()
        if output_emb is not None:
            # ignore the last classification head for text generation models
            head_name = find_module_name(output_emb)
            # This exclusion is best effort: if the head cannot be found among the model's modules, there is
            # nothing to exclude, so skip it instead of failing with an IndexError.
            if head_name is not None:
                module_names_to_exclude.add(head_name)
        elif peft_config.task_type == TaskType.SEQ_CLS:
            # ignore classifier head for classification models (issue 2027)
            # there is no fixed name for the classifier head, so check the common ones
            for attr_name in SEQ_CLS_HEAD_NAMES:
                cls_head = getattr(model, attr_name, None)
                if cls_head is not None:
                    head_name = find_module_name(cls_head)
                    if head_name is not None:
                        module_names_to_exclude.add(head_name)
                    # only the first head attribute that exists is considered
                    break

    # we don't want nested LoRA layers, i.e. LoRA being applied to possibly existing lora_A, lora_B, etc.
    # see 2390
    for prefix, module in named_modules:
        if isinstance(module, BaseTunerLayer):
            for suffix, _ in module.named_modules():
                # an empty suffix is the tuner layer itself, which must stay targetable
                if suffix:
                    # The root module has an empty name; joining with "." would produce ".suffix", which never
                    # matches the real module name, so use the bare suffix in that case.
                    module_names_to_exclude.add(f"{prefix}.{suffix}" if prefix else suffix)

    linear_module_names -= module_names_to_exclude
    peft_config.target_modules = linear_module_names
    return peft_config