in optimum/exporters/onnx/base.py
def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
    """Fills `inputs_or_outputs` with the dynamic axes of the past (or present) key/value tensors."""
    if direction not in ["inputs", "outputs"]:
        raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

    if direction == "inputs":
        decoder_sequence_name = "past_decoder_sequence_length"
        name = "past_key_values"
    else:
        decoder_sequence_name = "past_decoder_sequence_length + 1"
        name = "present"

    for i in range(self._normalized_config.decoder_num_layers):
        # Decoder self-attention cache: axis 2 is the sequence axis that grows during generation.
        inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name}
        inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name}

        # Encoder cross-attention cache entries.
        if (
            self.is_merged is True
            or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
            or direction == "inputs"
        ):
            # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export()
            # time we have currently no case to check whether we will merge at a later step or not (self.is_merged is
            # not yet set at this time)
            inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
            inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
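For illustration only (not part of the source): a minimal sketch of the mapping this method fills in when called with direction="inputs", assuming a hypothetical config with decoder_num_layers == 2.

# Illustration (assumption: decoder_num_layers == 2). With direction == "inputs",
# `name` is "past_key_values" and the encoder (cross-attention) entries are added
# for every layer, so the dynamic-axes dict passed in ends up holding:
expected_inputs = {
    "past_key_values.0.decoder.key": {0: "batch_size", 2: "past_decoder_sequence_length"},
    "past_key_values.0.decoder.value": {0: "batch_size", 2: "past_decoder_sequence_length"},
    "past_key_values.0.encoder.key": {0: "batch_size", 2: "encoder_sequence_length_out"},
    "past_key_values.0.encoder.value": {0: "batch_size", 2: "encoder_sequence_length_out"},
    "past_key_values.1.decoder.key": {0: "batch_size", 2: "past_decoder_sequence_length"},
    "past_key_values.1.decoder.value": {0: "batch_size", 2: "past_decoder_sequence_length"},
    "past_key_values.1.encoder.key": {0: "batch_size", 2: "encoder_sequence_length_out"},
    "past_key_values.1.encoder.value": {0: "batch_size", 2: "encoder_sequence_length_out"},
}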