def _revert()

in optimum/bettertransformer/models/base.py [0:0]


    def _revert(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.module_mapping is not None:
            if "" in self.module_mapping.values():
                for bt_module_attr_name, value in self.module_mapping.items():
                    if value == "":
                        module = getattr(self, bt_module_attr_name)
                        return module
            else:
                raise NotImplementedError("replacing a submodule in revert is not supported")

        for modified_layer_key_names, original_layer_key_names in self.original_layers_mapping.items():
            if isinstance(original_layer_key_names, list):
                current_weight = getattr(self, modified_layer_key_names)

                # Split the current weight n chunks - this is useful to split
                # the qkv layers into q, k, v layers for example.
                split_index = current_weight.shape[0] // len(original_layer_key_names)
                for i, subparam_name in enumerate(original_layer_key_names):
                    if recurse_getattr(module, subparam_name) is None:
                        # this is for example the case if bias=False is set for a nn.Linear layer
                        continue

                    if module not in self.keys_to_ignore:
                        # TODO: remove the clone once https://github.com/huggingface/transformers/pull/27314 & https://github.com/huggingface/safetensors/pull/379 are released.
                        # Safetensors is bugged when using views of tensors.
                        parameter = current_weight[i * split_index : (i + 1) * split_index].clone()
                        if isinstance(recurse_getattr(module, subparam_name), torch.nn.Parameter):
                            parameter = torch.nn.Parameter(parameter)
                        recurse_setattr(module, subparam_name, parameter)
            elif isinstance(original_layer_key_names, str):
                if recurse_getattr(module, original_layer_key_names) is None:
                    # this is for example the case if bias=False is set for a nn.Linear layer
                    continue

                parameter = getattr(self, modified_layer_key_names)
                if isinstance(recurse_getattr(module, original_layer_key_names), torch.nn.Parameter):
                    parameter = torch.nn.Parameter(parameter)
                recurse_setattr(module, original_layer_key_names, parameter)
            else:
                raise ValueError(
                    f"Invalid type {type(modified_layer_key_names)} for `original_layers_mapping`",
                    " please use either `str` or `list`.",
                )
        return module