in optimum/exporters/onnx/base.py [0:0]
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
    """
    Generates the dummy inputs of the wrapped ONNX config, extended with the extra inputs registered for
    `self.task` in `self._tasks_to_extra_inputs`.

    The wrapped config's dummy inputs are generated first. The batch size of its first input is then reused
    for every extra input, so all tensors agree on the batch dimension.

    Args:
        framework (`str`, defaults to `"pt"`):
            The framework for which to create the dummy inputs. It is forwarded both to the wrapped config and
            to each extra dummy input generator.
        **kwargs:
            Generation arguments forwarded to the wrapped config's `generate_dummy_inputs` and to the
            constructors of `self.DUMMY_EXTRA_INPUT_GENERATOR_CLASSES`. A caller-provided `sequence_length`
            is kept for the extra inputs unless past key values are used (see below). A caller-provided
            `batch_size` is replaced by the batch size actually produced by the wrapped config.

    Returns:
        The dummy inputs of the wrapped config, with one additional entry per extra input of `self.task`.

    Raises:
        RuntimeError: If no generator in `self.DUMMY_EXTRA_INPUT_GENERATOR_CLASSES` supports one of the extra
            inputs.
    """
    dummy_inputs = self._onnx_config.generate_dummy_inputs(framework=framework, **kwargs)

    # Measure the batch size the wrapped config really used, so the extra inputs match it.
    first_input_name = next(iter(self._onnx_config.inputs))
    batch_size = dummy_inputs[first_input_name].shape[0]

    extra_inputs = self._tasks_to_extra_inputs[self.task]

    # `kwargs` is a fresh dict local to this call, so updating it below does not affect the caller.
    # TODO: doesn't this break attention_mask generation?
    if (
        isinstance(self._onnx_config, OnnxConfigWithPast)
        and self._onnx_config.use_past_in_inputs is True
        and self.task != "text-generation"
    ):
        # NOTE(review): presumably with past key values the decoder consumes a single new token, hence a
        # sequence length of 1 for the extra inputs — confirm against the wrapped config's generators.
        kwargs["sequence_length"] = 1
    elif any("sequence_length" in dynamic_axes.values() for dynamic_axes in extra_inputs.values()):
        # Keep a caller-provided sequence length: the wrapped config's inputs were generated with it, so
        # forcing the default here would give extra inputs (e.g. labels) a mismatched length.
        kwargs.setdefault("sequence_length", DEFAULT_DUMMY_SHAPES["sequence_length"])

    # NOTE(review): presumably consumed by label generators for classification tasks — confirm.
    kwargs["num_labels"] = self._onnx_config._config.num_labels
    # Pass the batch size through `kwargs`. Passing it both as `batch_size=` and inside `**kwargs` raises a
    # TypeError ("multiple values for keyword argument") whenever the caller supplied one.
    kwargs["batch_size"] = batch_size

    dummy_inputs_generators = [
        cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_EXTRA_INPUT_GENERATOR_CLASSES
    ]

    for input_name in extra_inputs:
        # The first generator that supports the input wins; the `else` clause runs only when none did.
        for dummy_input_gen in dummy_inputs_generators:
            if dummy_input_gen.supports_input(input_name):
                dummy_inputs[input_name] = dummy_input_gen.generate(
                    input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
                )
                break
        else:
            raise RuntimeError(
                f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
            )

    return dummy_inputs