in optimum/bettertransformer/transformation.py

import types

from .models import BetterTransformerManager

def replace_to_bettertransformer(model, config):
r"""
Replaces the current model to its `BetterTransformer` implementation. Loops recursively into the model and replaces the
`Layer` modules with its `BetterTransformer` correspondant model
- Step 1: Recurse over the modules of the model
- Step 2: Verify if the module `BetterTransformer` is present for that model
- Step 3: If yes, replace the `...Layer` module with the `...LayerBetterTransformer` modules
- Step 4: If not, yield an error.
- Step 5: Post process the potentially converted model by setting the `is_last_layer` attribute to `True` for the last `BetterTransformer` layer.
(done in `set_last_layer` function)
Args:
`model` (`torch.nn.Module`):
The input model to convert
`config` (`transformers.PreTrainedConfig`):
The configuration dictionary of the model
Returns:
The converted model
"""
for name, module in model.named_children():
if hasattr(module, "SCB"):
            # modules loaded in 8-bit with bitsandbytes carry an `SCB` (scale) attribute and are not supported
            raise ValueError(
                "`load_in_8bit` and `BetterTransformer` are mutually exclusive:"
                " please pass a model that is not loaded in 8-bit."
            )
# replace the module if it is a transformer layer compatible with bettertransformer
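        # MODEL_MAPPING maps a model type to {original layer class name: BetterTransformer layer class},
        # e.g. (illustrative) {"bert": {"BertLayer": BertLayerBetterTransformer}}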
target_classes = list(BetterTransformerManager.MODEL_MAPPING[config.model_type].keys())
# We may want to override methods without having to override whole modules.
# For example, some methods handle the mask generation, which we do not need when using PyTorch SDPA.
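        # OVERWRITE_METHODS shape (inferred from the unpacking below):
        # {model_type: {class_name: (method_name, replacement_function)}}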
if config.model_type in BetterTransformerManager.OVERWRITE_METHODS:
for class_name, method_name_and_replacement in BetterTransformerManager.OVERWRITE_METHODS[
config.model_type
].items():
                if module.__class__.__name__ == class_name:
                    method_name, new_method = method_name_and_replacement
                    setattr(module, method_name, types.MethodType(new_method, module))
should_replace_module = False
for target_class in target_classes:
should_replace_module = module.__class__.__name__ == target_class
if should_replace_module:
bettertransformer_module = BetterTransformerManager.MODEL_MAPPING[config.model_type][target_class](
module, config
)
model._modules[name] = bettertransformer_module
break
        if len(list(module.children())) > 0 and not should_replace_module:
            # some submodules may be explicitly excluded from the BetterTransformer transform
            if name not in BetterTransformerManager.EXCLUDE_FROM_TRANSFORM.get(config.model_type, []):
                replace_to_bettertransformer(module, config)
return model
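
A minimal usage sketch (not part of the original file): in optimum this helper is normally reached through `BetterTransformer.transform`, but called directly on a supported architecture it would look roughly like this, assuming `config.model_type` is present in `BetterTransformerManager.MODEL_MAPPING`:

    from transformers import AutoModel

    # load a supported encoder model, then swap its layers in-place
    model = AutoModel.from_pretrained("bert-base-uncased")
    model = replace_to_bettertransformer(model, model.config)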