in optimum/exporters/onnx/model_patcher.py [0:0]
def patched_forward(*args, **kwargs):
    signature = inspect.signature(self.super_patched_forward)
    args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

    outputs = self.super_patched_forward(*args, **kwargs)

    # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
    filtered_outputs = {}
    for name, value in outputs.items():
        onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
        if (
            onnx_output_name in config.outputs
            or (allow_past_in_outputs and name.startswith("past_key_values"))
            or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
        ):
            if name != "past_key_values":
                if self.real_config._behavior == "decoder" and name == "encoder_last_hidden_state":
                    # Who cares about the encoder outputs in the decoder?
                    continue
                else:
                    filtered_outputs[name] = value
            else:
                if self.real_config._behavior == "monolith" or (
                    self.real_config._behavior == "decoder"
                    and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
                ):
                    filtered_outputs[name] = value
                elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
                    # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
                    filtered_outputs[name] = tuple([v[:2] for v in value])

    return filtered_outputs
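
For reference, below is a minimal, self-contained sketch (not part of model_patcher.py) of the past_key_values filtering performed above. It assumes the usual Transformers seq2seq cache layout, where each per-layer entry of past_key_values is a 4-tuple of (self-attention key, self-attention value, cross-attention key, cross-attention value); the tensor shapes and variable names are illustrative only.

    # Sketch: why `v[:2]` is enough for the autoregressive decoder export.
    # The cross-attention key/value pair depends only on the encoder hidden
    # states, so it is constant across decoding steps and is dropped here.
    import torch

    num_layers, batch, heads, past_len, src_len, head_dim = 2, 1, 4, 3, 7, 8

    # Hypothetical cache with the assumed per-layer 4-tuple layout.
    past_key_values = tuple(
        (
            torch.randn(batch, heads, past_len, head_dim),  # self-attention key
            torch.randn(batch, heads, past_len, head_dim),  # self-attention value
            torch.randn(batch, heads, src_len, head_dim),   # cross-attention key (constant)
            torch.randn(batch, heads, src_len, head_dim),   # cross-attention value (constant)
        )
        for _ in range(num_layers)
    )

    # Same filtering as in patched_forward: keep only the self-attention cache.
    filtered = tuple(v[:2] for v in past_key_values)

    assert all(len(v) == 2 for v in filtered)
    print([tuple(t.shape) for t in filtered[0]])  # self-attention key/value shapes only

Because the cross-attention key/values never change between decoding steps, keeping only the first two entries per layer avoids exporting constant outputs from the decoder that uses a KV cache, which is exactly what the `use_past_in_inputs=True` branch above does.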