in optimum/bettertransformer/transformation.py
def reverse(bt_model: "PreTrainedModel") -> "PreTrainedModel":
"""
Converts back a model using BetterTransformer to its canonical transformers modeling implementation, in order to save
and share it.
Args:
bt_model (`PreTrainedModel`):
Model using BetterTransform to convert back to use transformers modeling.
Returns:
PreTrainedModel: _description_
"""
    if getattr(bt_model, "use_bettertransformer", False) is False:
        raise ValueError(
            "The method BetterTransformer.reverse() should be used on a model already transformed to the BetterTransformer"
            " format, which appears to not be the case."
        )
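    # torch.device("meta") as a context manager (used below) is only available from PyTorch 2.0;
    # early 2.0 pre-release builds were still versioned 1.14, hence the comparison below.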
    if parse(torch.__version__) <= parse("1.14"):
        raise ValueError(
            f"BetterTransformer reverse transform requires torch>=2.0 but {torch.__version__} is installed. Please upgrade PyTorch."
        )
    config = bt_model.config
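    # Instantiate the original architecture on the meta device so that no memory is allocated for weights
    # that are copied over from bt_model below.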
    if config.model_type not in ["wav2vec2", "hubert", "bark"]:
        with torch.device("meta"):
            reversed_model = bt_model.__class__(config)
    else:
        # TODO: fix once this is fixed in pytorch
        # reference: https://github.com/pytorch/pytorch/issues/96409
        logger.warning(
            "The reverse transform for the architectures wav2vec2, hubert, bark is memory-heavy due to a bug in PyTorch."
        )
        reversed_model = bt_model.__class__(config)
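    # Newly instantiated modules default to training mode; match the source model's eval mode if needed.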
    if bt_model.training is False:
        reversed_model = reversed_model.eval()
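    # Paths of modules reverted in the loop below; their children are skipped on subsequent iterations.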
    reversed_modules_paths = []
    for path, module in reversed_model.named_modules():
        if path.startswith(tuple(reversed_modules_paths)):
            continue
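        # Submodules excluded from the BetterTransformer transform for this architecture were never modified,
        # so they are skipped here and copied back as-is in the loops below.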
        if config.model_type in BetterTransformerManager.EXCLUDE_FROM_TRANSFORM and any(
            subname in path for subname in BetterTransformerManager.EXCLUDE_FROM_TRANSFORM[config.model_type]
        ):
            continue
        target_classes = list(BetterTransformerManager.MODEL_MAPPING[config.model_type].keys())
        has_been_replaced = False
        for target_class in target_classes:
            if module.__class__.__name__ == target_class:
                has_been_replaced = True
                break
        # replace parameters, buffers (or possibly full modules) that were modified by the bettertransformer transform
        if has_been_replaced:
            recurse_setattr(reversed_model, path, recurse_getattr(bt_model, path)._revert(module))
            reversed_modules_paths.append(path + ".")  # add a . to avoid issues with startswith
    # replace back parameters and buffers that were untouched by the bettertransformer transform
    for path, param in reversed_model.state_dict().items():
        if param.device == torch.device("meta") or not path.startswith(tuple(reversed_modules_paths)):
            recurse_setattr(reversed_model, path, recurse_getattr(bt_model, path))
    # some buffers may be non-persistent, hence not in the state_dict (as token_type_ids for some models)
    for path, param in reversed_model.named_buffers():
        if param.device == torch.device("meta") or not path.startswith(tuple(reversed_modules_paths)):
            recurse_setattr(reversed_model, path, recurse_getattr(bt_model, path))
    return reversed_model
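For reference, a minimal round-trip sketch using the public BetterTransformer.transform / BetterTransformer.reverse entry points. The checkpoint name and output directory are illustrative, and reverse() requires torch>=2.0 as checked above:

from transformers import AutoModelForSequenceClassification
from optimum.bettertransformer import BetterTransformer

# Load a regular transformers model and convert it to the BetterTransformer format.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
bt_model = BetterTransformer.transform(model)

# ... run accelerated inference with bt_model ...

# Convert back to the canonical transformers implementation before saving/sharing.
reversed_model = BetterTransformer.reverse(bt_model)
reversed_model.save_pretrained("bert_reversed")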