in optimum/exporters/openvino/model_patcher.py
def patched_forward(*args, **kwargs):
    # Wrap the original forward so that a legacy tuple-format KV cache is
    # converted to a transformers DynamicCache before the call and back to
    # the legacy layout afterwards.
    from transformers.cache_utils import DynamicCache

    signature = inspect.signature(self.orig_forward)
    args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

    return_legacy_cache = False
    pkv_in_args = False
    legacy_pkv = None
    if "past_key_values" in kwargs:
        legacy_pkv = kwargs.pop("past_key_values", None)

    # Locate the relevant parameters in the original signature so that values
    # passed positionally can be found and replaced in `args`.
    sign_names = list(signature.parameters.keys())
    pkv_argument_index = sign_names.index("past_key_values")
    cache_position_index = sign_names.index("cache_position") if "cache_position" in sign_names else -1
    input_ids_index = sign_names.index("input_ids" if "input_ids" in sign_names else "inputs_embeds")

    # past_key_values may also arrive positionally rather than as a keyword.
    if legacy_pkv is None and len(args) > pkv_argument_index:
        legacy_pkv = args[pkv_argument_index]
        pkv_in_args = True
    # Convert the legacy tuple-of-tuples cache to a DynamicCache, writing it
    # back to wherever it came from (kwargs or positional args).
    if legacy_pkv is not None:
        pkv = DynamicCache.from_legacy_cache(legacy_pkv)
        return_legacy_cache = True
        if not pkv_in_args:
            kwargs["past_key_values"] = pkv
        else:
            args[pkv_argument_index] = pkv
    # Models that accept `cache_position` expect it alongside the cache; if
    # the caller supplied neither a positional value (the parameter index is
    # at or beyond len(args)) nor a keyword, derive it from the cached
    # sequence length and the length of the new input.
    if (
        return_legacy_cache
        and cache_position_index != -1
        and cache_position_index >= len(args)
        and "cache_position" not in kwargs
    ):
        past_seen_tokens = legacy_pkv[0][0].shape[-2]
        input_ids = args[input_ids_index] if "input_ids" not in kwargs else kwargs["input_ids"]
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
        )
        kwargs["cache_position"] = cache_position
    outputs = self.orig_forward(*args, **kwargs)

    # Convert the cache in the output back to the legacy layout so callers
    # see the same structure they passed in.
    if return_legacy_cache:
        outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

    return outputs
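
For context, the round trip this patch relies on can be exercised in isolation. The following is a minimal sketch, not repository code; the tensor shapes and the single-layer cache are illustrative assumptions:

import torch
from transformers.cache_utils import DynamicCache

# One decoder layer with 3 cached tokens; shape (batch, heads, seq_len, head_dim).
key = torch.zeros(1, 8, 3, 64)
value = torch.zeros(1, 8, 3, 64)
legacy = ((key, value),)  # legacy format: one (key, value) pair per layer

cache = DynamicCache.from_legacy_cache(legacy)

# Same lookup the patch uses for the cached sequence length.
past_seen_tokens = legacy[0][0].shape[-2]  # 3

# cache_position for a single new token, as the patch computes it.
input_len = 1
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + input_len)
print(cache_position)  # tensor([3])

# Converting back yields the original legacy structure.
roundtrip = cache.to_legacy_cache()
assert torch.equal(roundtrip[0][0], key)
assert torch.equal(roundtrip[0][1], value)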