in optimum/exporters/onnx/model_patcher.py [0:0]
def patched_forward(*args, **kwargs):
    """Call the original forward and return its outputs as a filtered dict.

    The call arguments are first passed through ``override_arguments`` together
    with ``self.model_kwargs``. With transformers >= 4.48, a legacy (nested
    list/tuple) ``past_key_values`` argument is converted to a cache object
    before the call, and a cache object found in the outputs is converted back
    to the legacy format afterwards.

    ``self``, ``config`` and ``allow_past_in_outputs`` are free variables that
    are not defined in this function; they are expected to be provided by the
    enclosing scope.

    Returns:
        A dict of outputs: filtered by ``config.outputs`` when the model
        returns a dict, zipped with the ``config.outputs`` keys when it
        returns a list/tuple, or a single entry otherwise.

    Raises:
        ValueError: If the entries of a legacy ``past_key_values`` have neither
            2 nor 4 elements, or if the model returns a single non-container
            output while ``config.outputs`` declares more than one output.
    """

    def _is_legacy_cache(value):
        # A legacy cache is a non-empty list/tuple whose entries are themselves
        # lists/tuples. The emptiness check avoids an IndexError on `value[0]`.
        return isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], (list, tuple))

    signature = inspect.signature(self.orig_forward)
    args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

    if is_transformers_version(">=", "4.48") and "past_key_values" in signature.parameters:
        pkv_index = list(signature.parameters.keys()).index("past_key_values")

        # The positional slot takes precedence; kwargs are only inspected when the
        # positional value is absent or is not a legacy cache.
        pkv_in_args = pkv_index < len(args) and _is_legacy_cache(args[pkv_index])
        pkv_in_kwargs = not pkv_in_args and _is_legacy_cache(kwargs.get("past_key_values"))

        if pkv_in_args or pkv_in_kwargs:
            legacy_cache = args[pkv_index] if pkv_in_args else kwargs["past_key_values"]
            entry_size = len(legacy_cache[0])

            # Entries with 2 elements are converted with DynamicCache and entries
            # with 4 elements with EncoderDecoderCache.
            if entry_size == 2:
                cache = DynamicCache.from_legacy_cache(legacy_cache)
            elif entry_size == 4:
                cache = EncoderDecoderCache.from_legacy_cache(legacy_cache)
            else:
                raise ValueError(
                    f"past_key_values should have either 2 or 4 elements, but it has {entry_size} elements"
                )

            if pkv_in_args:
                # NOTE(review): `override_arguments` is not visible here; copying to a
                # list keeps the item assignment valid even if it returns a tuple.
                args = list(args)
                args[pkv_index] = cache
            else:
                kwargs["past_key_values"] = cache

    outputs = self.orig_forward(*args, **kwargs)

    # This code block handles different cases of the filtered_outputs input to align it with the expected
    # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
    # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
    # contains the output names of the model. In the case of Timm classification models, the output
    # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
    # match the outputs in order.
    filtered_outputs = {}
    if isinstance(outputs, dict):
        for name, value in outputs.items():
            onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
            if (
                onnx_output_name in config.outputs
                or (allow_past_in_outputs and name.startswith("past_key_values"))
                or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
            ):
                # The entry is stored under the original (torch) output name.
                filtered_outputs[name] = value
    elif isinstance(outputs, (list, tuple)):
        # Positional outputs are paired with the ONNX config output names in order.
        filtered_outputs = dict(zip(config.outputs.keys(), outputs))
    else:
        if len(config.outputs) > 1:
            num_outputs = len(config.outputs)
            outputs_str = ", ".join(config.outputs.keys())
            raise ValueError(
                f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
            )
        name = list(config.outputs.keys())[0]
        filtered_outputs[name] = outputs

    if is_transformers_version(">=", "4.48"):
        # Read the cache from `filtered_outputs` rather than `outputs`: `outputs` may be
        # a list/tuple, which cannot be indexed by the "past_key_values" key.
        output_cache = filtered_outputs.get("past_key_values")
        if isinstance(output_cache, (DynamicCache, EncoderDecoderCache)):
            filtered_outputs["past_key_values"] = output_cache.to_legacy_cache()

    return filtered_outputs