in optimum/exporters/onnx/model_configs.py [0:0]
def inputs_for_causal_lm(self):
    """Build the dynamic-axes specification for the causal-LM (decoder) inputs.

    Returns:
        A dict mapping each input name to a ``{axis_index: axis_name}`` dict.
        ``input_ids`` and ``attention_mask`` are always present. When
        ``self.use_past_in_inputs`` is truthy, the attention mask's axis 1 is
        labelled ``"past_sequence_length + 1"``. In that case the dict also
        contains one ``past_key_values.{i}.key`` entry and one
        ``past_key_values.{i}.value`` entry for each
        ``i in range(self._normalized_config.decoder_num_layers)``, in that
        order.
    """
    # Without a cache, both inputs share the same symbolic sequence axis.
    if not self.use_past_in_inputs:
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }

    common_inputs = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        # NOTE(review): this label implies one new token per step, while
        # input_ids axis 1 is the independent "sequence_length" axis; confirm
        # whether multi-token steps with a cache are expected.
        "attention_mask": {0: "batch_size", 1: "past_sequence_length + 1"},
    }

    num_layers = self._normalized_config.decoder_num_layers
    for layer_idx in range(num_layers):
        # Key before value for each layer; each entry gets its own fresh dict.
        for tensor_kind in ("key", "value"):
            common_inputs[f"past_key_values.{layer_idx}.{tensor_kind}"] = {
                0: "batch_size",
                2: "past_sequence_length",
            }
    return common_inputs