in optimum/exporters/onnx/model_patcher.py [0:0]
def patched_forward(*args, **kwargs):
    """Wrap ``self.orig_forward`` for ONNX export.

    Forces eager (non-SDPA) attention by injecting ``output_attentions=True``
    into the model kwargs, then filters the model outputs down to those
    declared in ``config.outputs`` (plus ``past_key_values`` entries when
    ``allow_past_in_outputs`` is set).
    """
    # Copy so the per-export override below does not leak into the shared
    # ``self.model_kwargs`` dict (the original aliased and mutated it in place).
    model_kwargs = dict(self.model_kwargs)
    # setting output_attentions=True in the model input to avoid calling torch.nn.functional.scaled_dot_product_attention
    # in https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/wavlm/modeling_wavlm.py#L496
    # that calls https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/functional.py#L5334
    model_kwargs["output_attentions"] = True
    signature = inspect.signature(self.orig_forward)
    args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs)

    outputs = self.orig_forward(*args, **kwargs)

    # Keep only outputs the ONNX config declares; torch-side names are mapped
    # to ONNX names through ``config.torch_to_onnx_output_map`` before matching.
    filtered_outputs = {}
    for name, value in outputs.items():
        onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
        if (
            onnx_output_name in config.outputs
            or (allow_past_in_outputs and name.startswith("past_key_values"))
            # prefix match covers per-layer outputs like "present.0.key"
            or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
        ):
            filtered_outputs[name] = value
    return filtered_outputs