in optimum/exporters/openvino/model_patcher.py [0:0]
def __enter__(self):
    """Force SDPA attention for qwen2/llama models, then enter the wrapped patcher.

    With torch >= 2.1.0, when the model type is ``"qwen2"`` or ``"llama"`` and the
    configured attention implementation is not ``"sdpa"``:

    * the current value is saved as ``config._orig_attn_implementation`` and the
      config is switched to ``"sdpa"``;
    * for transformers < 4.48 (qwen2) or < 4.47 (llama), every layer's
      ``self_attn.forward`` is rebound to the ``forward`` of the SDPA attention
      class, and the previous bound method is stored on the attention module as
      ``_orig_forward``.

    Returns:
        The result of ``self._internal_patcher.__enter__()`` when an internal
        patcher is set, otherwise the result of ``super().__enter__()``.
    """
    if is_torch_version(">=", "2.1.0"):
        config = self._model.config
        if config.model_type in ["qwen2", "llama"] and config._attn_implementation != "sdpa":
            # Keep the previous setting on the config before overriding it.
            # NOTE(review): presumably read back by __exit__ to restore it - confirm.
            config._orig_attn_implementation = config._attn_implementation
            config._attn_implementation = "sdpa"

            # The per-architecture class tables are imported lazily and only under
            # the version gate.
            # NOTE(review): presumably they are absent in newer transformers - confirm.
            legacy_sdpa_cls = None
            if config.model_type == "qwen2" and is_transformers_version("<", "4.48"):
                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES

                legacy_sdpa_cls = QWEN2_ATTENTION_CLASSES["sdpa"]
            elif config.model_type == "llama" and is_transformers_version("<", "4.47"):
                from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

                legacy_sdpa_cls = LLAMA_ATTENTION_CLASSES["sdpa"]

            if legacy_sdpa_cls is not None:
                for decoder_layer in self._model.model.layers:
                    attention = decoder_layer.self_attn
                    # Stash the current bound forward, then bind the SDPA
                    # implementation to this attention instance.
                    attention._orig_forward = attention.forward
                    attention.forward = types.MethodType(legacy_sdpa_cls.forward, attention)

    delegate = self._internal_patcher
    if delegate is None:
        return super().__enter__()
    return delegate.__enter__()