def _check_add_weighted_adapter()

in src/peft/tuners/ia3/model.py [0:0]


    def _check_add_weighted_adapter(self, adapters: list[str]) -> tuple[str, str]:
        """
        Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
        model.
        """
        # Validate existence of adapters
        for adapter in adapters:
            if adapter not in self.peft_config:
                raise ValueError(f"Adapter {adapter} does not exist")

        # Check for conflicting modules_to_save
        modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
        if any(
            sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 for wrapper in modules_to_save_wrappers
        ):
            raise ValueError("Cannot add weighted adapters targeting the same module with modules_to_save.")

        # Ensure all adapters have compatible target and feedforward module types
        target_module_types = {type(self.peft_config[adapter].target_modules) for adapter in adapters}
        feedforward_module_types = {type(self.peft_config[adapter].feedforward_modules) for adapter in adapters}
        if len(target_module_types) > 1 or len(feedforward_module_types) > 1:
            raise ValueError("All adapter configs should have the same type for target and feedforward modules.")

        # Combine target and feedforward modules
        if str in target_module_types:
            new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
        else:
            new_target_modules = set.union(*(self.peft_config[adapter].target_modules for adapter in adapters))

        if str in feedforward_module_types:
            new_feedforward_modules = "|".join(
                f"({self.peft_config[adapter].feedforward_modules})" for adapter in adapters
            )
        else:
            new_feedforward_modules = set.union(
                *(self.peft_config[adapter].feedforward_modules for adapter in adapters)
            )

        return new_target_modules, new_feedforward_modules