def patched_forward()

in optimum/exporters/onnx/model_patcher.py [0:0]


        def patched_forward(*args, **kwargs):
            """Call the original forward and return its outputs as a ``name -> value`` dict.

            The call arguments are first overridden with ``self.model_kwargs`` through
            ``override_arguments``. With transformers >= 4.48, a ``past_key_values`` argument
            passed in the legacy nested list/tuple format is converted to a cache object before
            the call, and a cache object stored under ``"past_key_values"`` in the filtered
            outputs is converted back with ``to_legacy_cache()``.

            Args:
                *args: Positional arguments forwarded to ``self.orig_forward``.
                **kwargs: Keyword arguments forwarded to ``self.orig_forward``.

            Returns:
                dict: The retained outputs. For dict outputs the original (torch) output names
                are kept as keys; for list/tuple outputs the keys are the names of
                ``config.outputs`` taken in order; a single bare output is stored under the
                only name declared in ``config.outputs``.

            Raises:
                ValueError: If the first entry of a legacy ``past_key_values`` has neither 2 nor
                    4 elements, or if the model returns a single bare output while
                    ``config.outputs`` declares more than one output.
            """

            def _is_legacy_cache(value):
                # Legacy format: a non-empty list/tuple whose first entry is itself a list/tuple.
                # The emptiness check avoids an IndexError on `value[0]`; an empty value is
                # passed to the model unchanged.
                return isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], (list, tuple))

            def _from_legacy_cache(legacy_cache):
                # The size of the first entry selects the cache class: 2 elements ->
                # DynamicCache, 4 elements -> EncoderDecoderCache.
                entry_size = len(legacy_cache[0])
                if entry_size == 2:
                    return DynamicCache.from_legacy_cache(legacy_cache)
                if entry_size == 4:
                    return EncoderDecoderCache.from_legacy_cache(legacy_cache)
                raise ValueError(
                    f"past_key_values should have either 2 or 4 elements, but it has {entry_size} elements"
                )

            signature = inspect.signature(self.orig_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

            if is_transformers_version(">=", "4.48") and "past_key_values" in signature.parameters:
                pkv_index = list(signature.parameters.keys()).index("past_key_values")

                if pkv_index < len(args) and _is_legacy_cache(args[pkv_index]):  # pkv is in args
                    # Copy into a list so the positional slot can be reassigned whatever
                    # sequence type `override_arguments` handed back.
                    args = list(args)
                    args[pkv_index] = _from_legacy_cache(args[pkv_index])
                elif _is_legacy_cache(kwargs.get("past_key_values")):  # pkv is in kwargs
                    kwargs["past_key_values"] = _from_legacy_cache(kwargs["past_key_values"])

            outputs = self.orig_forward(*args, **kwargs)

            # This code block handles different cases of the filtered_outputs input to align it with the
            # expected format of outputs. It is common for the output type of a model to vary, such as
            # tensor, list, tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput
            # object that contains the output names of the model. In the case of Timm classification models,
            # the output is of type tensor. By default, it is assumed that the output names mentioned in the
            # ONNX config match the outputs in order.
            filtered_outputs = {}
            if isinstance(outputs, dict):
                for name, value in outputs.items():
                    onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                    if (
                        onnx_output_name in config.outputs
                        or (allow_past_in_outputs and name.startswith("past_key_values"))
                        or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                    ):
                        # The entry is stored under the torch-side name, not the mapped ONNX name.
                        filtered_outputs[name] = value
            elif isinstance(outputs, (list, tuple)):
                # Positional outputs are paired with the configured output names in order;
                # zip stops at the shorter of the two.
                filtered_outputs = dict(zip(config.outputs.keys(), outputs))
            else:
                # A single bare output can only be named if the config declares one output.
                if len(config.outputs) > 1:
                    num_outputs = len(config.outputs)
                    outputs_str = ", ".join(config.outputs.keys())
                    raise ValueError(
                        f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
                    )
                name = list(config.outputs.keys())[0]
                filtered_outputs[name] = outputs

            if is_transformers_version(">=", "4.48"):
                # Read the cache from `filtered_outputs`, not `outputs`: when the model returned
                # a list/tuple or a bare object, `outputs` is not keyed by output name.
                past_key_values = filtered_outputs.get("past_key_values")
                if isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
                    filtered_outputs["past_key_values"] = past_key_values.to_legacy_cache()

            return filtered_outputs