torchbenchmark/util/framework/huggingface/patch_hf.py:
"""
Patch the transformer source code to enable optimizations.
"""
import os
import subprocess
import sys
from .model_factory import class_models
# The config classes below are referenced by name inside the eval()'d
# constructor strings stored in class_models, so they must be in scope here.
from transformers import AutoConfig, ReformerConfig, BigBirdConfig, BertConfig
PATCH_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "patches")
def cache_model(name: str):
    """Instantiate the model for the given benchmark name, caching its artifacts."""
    import transformers
    # class_models[name][2] holds a config-constructor expression as a string
    # (e.g. "BertConfig()"); class_models[name][3] names the transformers class.
    model_config = eval(class_models[name][2])
    model_ctor = getattr(transformers, class_models[name][3])
    model_ctor.from_config(model_config)
def patch_transformers():
    """Apply every patch file in PATCH_DIR to the installed transformers package."""
    import transformers
    transformers_dir = os.path.dirname(transformers.__file__)
    for patch_file in os.listdir(PATCH_DIR):
        patch_file_fullpath = os.path.join(PATCH_DIR, patch_file)
        try:
            subprocess.check_output(
                ["patch", "-p1", "--forward", "-i", patch_file_fullpath, "-r", "/tmp/rej"],
                cwd=transformers_dir,
            )
        except subprocess.CalledProcessError as e:
            output_str = e.output.decode()
            # With "--forward", patch exits nonzero and prints this message when
            # the patch is already applied; treat that as an already-patched tree.
            if "previously applied" in output_str:
                return
            else:
                print(output_str)
                sys.exit(1)
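
A minimal usage sketch, not part of the upstream file ("hf_Bert" is assumed to be a valid key in class_models; substitute any key your model_factory defines):

if __name__ == "__main__":
    patch_transformers()    # idempotent: "patch --forward" skips already-applied patches
    cache_model("hf_Bert")  # builds the model from its config (randomly initialized)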