def reverse()

in optimum/bettertransformer/transformation.py [0:0]


    def reverse(bt_model: "PreTrainedModel") -> "PreTrainedModel":
        """
        Converts back a model using BetterTransformer to its canonical transformers modeling implementation, in order to save
        and share it.

        A new instance of `bt_model.__class__` is built from `bt_model.config`. It is built on the meta device, except for
        the wav2vec2, hubert and bark architectures. Every submodule of the new instance whose class name is one of the
        BetterTransformer target classes for this `model_type` is replaced by the result of calling `_revert` on the layer
        found at the same path in `bt_model`. All remaining parameters and buffers are then taken from `bt_model`.

        Args:
            bt_model (`PreTrainedModel`):
                Model using BetterTransformer to convert back to use transformers modeling.

        Returns:
            `PreTrainedModel`: A new model of the same class as `bt_model`, using the canonical transformers modeling
            implementation. It is put in eval mode when `bt_model` is not in training mode.

        Raises:
            ValueError: If `bt_model` was not transformed with BetterTransformer (its `use_bettertransformer` attribute is
                missing or `False`), or if the installed PyTorch version is older than required (torch>=2.0).
        """
        if getattr(bt_model, "use_bettertransformer", False) is False:
            raise ValueError(
                "The method BetterTransformer.reverse() should be used on a model already transformed to the BetterTransformer"
                " format, which appears to not be the case."
            )

        # `torch.device(...)` used as a context manager (below) needs PyTorch 2.0. The bound is 1.14 rather than 2.0,
        # presumably so that 2.0 pre-release / nightly builds (e.g. "2.0.0.devYYYYMMDD"), which compare lower than "2.0",
        # are still accepted.
        if parse(torch.__version__) <= parse("1.14"):
            raise ValueError(
                f"BetterTransformer reverse transform requires torch>=2.0 but {torch.__version__} is installed. Please upgrade PyTorch."
            )
        config = bt_model.config
        model_type = config.model_type

        if model_type not in ["wav2vec2", "hubert", "bark"]:
            # Build the skeleton on the meta device so its weights are not materialized. Every tensor is replaced
            # below by a reverted module or by the matching tensor from `bt_model`.
            with torch.device("meta"):
                reversed_model = bt_model.__class__(config)
        else:
            # TODO: fix once this is fixed in pytorch
            # reference: https://github.com/pytorch/pytorch/issues/96409
            logger.warning(
                "The reverse transform for the architectures wav2vec2, hubert, bark is memory-heavy due to a bug in PyTorch."
            )
            reversed_model = bt_model.__class__(config)

        if not bt_model.training:
            reversed_model = reversed_model.eval()

        # Loop invariants: the class names swapped by the BetterTransformer transform for this architecture,
        # and the path substrings that the transform is configured to leave alone.
        target_classes = set(BetterTransformerManager.MODEL_MAPPING[model_type].keys())
        excluded_subnames = (
            BetterTransformerManager.EXCLUDE_FROM_TRANSFORM[model_type]
            if model_type in BetterTransformerManager.EXCLUDE_FROM_TRANSFORM
            else ()
        )

        reversed_modules_paths = []
        for path, module in reversed_model.named_modules():
            # Skip anything nested under a module already replaced below: that whole subtree was swapped out by
            # `recurse_setattr`, and what `named_modules()` still yields here are children of the old module object.
            if path.startswith(tuple(reversed_modules_paths)):
                continue

            if any(subname in path for subname in excluded_subnames):
                continue

            # replace parameters, buffers (or possibly full modules) that were modified by the bettertransformer transform
            if module.__class__.__name__ in target_classes:
                recurse_setattr(reversed_model, path, recurse_getattr(bt_model, path)._revert(module))
                # Trailing "." so that e.g. "layers.1." does not also match "layers.10...".
                reversed_modules_paths.append(path + ".")

        # The set of reverted subtrees is final from here on; build the prefix tuple and the device once.
        reverted_prefixes = tuple(reversed_modules_paths)
        meta_device = torch.device("meta")

        # replace back parameters and buffers that were untouched by the bettertransformer transform.
        # Inside a reverted subtree, a tensor still on the meta device was not populated by `_revert`,
        # so it is taken from `bt_model` as well.
        for path, param in reversed_model.state_dict().items():
            if param.device == meta_device or not path.startswith(reverted_prefixes):
                recurse_setattr(reversed_model, path, recurse_getattr(bt_model, path))

        # some buffers may be non-persistent, hence not in the state_dict (as token_type_ids for some models)
        for path, param in reversed_model.named_buffers():
            if param.device == meta_device or not path.startswith(reverted_prefixes):
                recurse_setattr(reversed_model, path, recurse_getattr(bt_model, path))

        return reversed_model