def patched_forward()

in optimum/exporters/openvino/model_patcher.py [0:0]


        def patched_forward(*args, **kwargs):
            """Run ``self.orig_forward`` with a legacy cache converted to ``DynamicCache``.

            The incoming ``past_key_values`` (keyword or positional, located via the
            signature of ``self.orig_forward``) is wrapped with
            ``DynamicCache.from_legacy_cache`` before the call. When that happens,
            the ``past_key_values`` of the returned outputs is converted back with
            ``to_legacy_cache()``. If the wrapped forward accepts ``cache_position``
            and the caller did not supply it, it is derived from the cached length
            and the length of the current input.

            Args:
                *args: Positional arguments for ``self.orig_forward``.
                **kwargs: Keyword arguments for ``self.orig_forward``.

            Returns:
                Whatever ``self.orig_forward`` returns; its ``past_key_values``
                attribute is replaced by the legacy representation when a legacy
                cache was passed in.

            Raises:
                ValueError: If the signature of ``self.orig_forward`` has no
                    ``past_key_values`` parameter, or neither ``input_ids`` nor
                    ``inputs_embeds``.
            """
            from transformers.cache_utils import DynamicCache

            signature = inspect.signature(self.orig_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
            # The cache may be swapped in place among the positionals below, so make
            # sure they are held in a mutable sequence regardless of what
            # override_arguments returned.
            args = list(args)

            sign_names = list(signature.parameters.keys())
            pkv_argument_index = sign_names.index("past_key_values")
            cache_position_index = sign_names.index("cache_position") if "cache_position" in sign_names else -1
            input_name = "input_ids" if "input_ids" in sign_names else "inputs_embeds"
            input_ids_index = sign_names.index(input_name)

            # A keyword cache takes precedence; otherwise fall back to the positional slot.
            legacy_pkv = kwargs.pop("past_key_values", None)
            pkv_in_args = False
            if legacy_pkv is None and len(args) > pkv_argument_index:
                legacy_pkv = args[pkv_argument_index]
                pkv_in_args = True

            return_legacy_cache = legacy_pkv is not None
            if return_legacy_cache:
                pkv = DynamicCache.from_legacy_cache(legacy_pkv)
                # Hand the converted cache back through the same channel it came from.
                if pkv_in_args:
                    args[pkv_argument_index] = pkv
                else:
                    kwargs["past_key_values"] = pkv

            # Positionals occupy slots 0..len(args)-1, so cache_position was not
            # given positionally exactly when its index is >= len(args).
            cache_position_missing = (
                cache_position_index != -1
                and cache_position_index >= len(args)
                and "cache_position" not in kwargs
            )
            if return_legacy_cache and cache_position_missing:
                # NOTE(review): assumes the second-to-last dim of the first cached
                # tensor is the number of already-processed tokens — confirm.
                past_seen_tokens = legacy_pkv[0][0].shape[-2]
                # The current input may arrive by keyword under either name, or positionally.
                model_input = next(
                    (kwargs[name] for name in ("input_ids", "inputs_embeds") if kwargs.get(name) is not None),
                    None,
                )
                if model_input is None:
                    model_input = args[input_ids_index]
                # NOTE(review): assumes dim 1 of the input is the sequence length — confirm.
                kwargs["cache_position"] = torch.arange(
                    past_seen_tokens, past_seen_tokens + model_input.shape[1], device=model_input.device
                )

            outputs = self.orig_forward(*args, **kwargs)
            # Only convert back when a cache was actually returned.
            if return_legacy_cache and outputs.past_key_values is not None:
                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

            return outputs