in src/exporters/coreml/config.py [0:0]
def get_flexible_outputs(self) -> Mapping[str, List[Mapping[str, int]]]:
    """
    Work out which outputs need flexible shapes, and along which axes.

    Flexible output shapes are needed when the model accepts a variable-length sequence.
    Either `sequence_length` on the model input is a ``(min, max)`` range of allowed
    lengths, or the config uses past key/values or is a seq2seq model. In the latter case
    the range is always ``(1, -1)``.

    Returns:
        A dict mapping output name to a list of ``{"axis", "min", "max"}`` dicts. It is
        empty when `self.task` does not produce a sequence. It is also empty when the input
        sequence length is not a range and the model neither uses past key/values nor is
        seq2seq.
    """
    flexible_shapes = {}

    # Only tasks whose outputs are sequences can have a variable-length output.
    sequence_tasks = (
        "feature-extraction",
        "text-generation",
        "automatic-speech-recognition",
        "fill-mask",
        "question-answering",
        "text2text-generation",
        "speech-seq2seq",
        "token-classification",
    )
    if self.task not in sequence_tasks:
        return flexible_shapes

    inputs = self.inputs
    outputs = self.outputs

    # A flexible input length implies a flexible output length. Models using past
    # key/values, and seq2seq models, always get the range (1, -1).
    # NOTE(review): -1 presumably means "no upper bound" to the downstream converter;
    # confirm.
    if self.use_past or self.seq2seq:
        lower, upper = 1, -1
    else:
        lower = upper = None
        seq_len = self.get_input_sequence_length(inputs)
        # Only a (min, max) tuple makes the outputs flexible; any other value is ignored.
        if isinstance(seq_len, tuple):
            lower, upper = seq_len

    if lower is not None:
        # TODO: the batch axis (0) is not made flexible yet.
        sequence_outputs = ("last_hidden_state", "logits", "start_logits", "end_logits")
        flexible_shapes.update(
            (name, [{"axis": 1, "min": lower, "max": upper}])
            for name in sequence_outputs
            if name in outputs
        )

    if self.use_past:
        # TODO: two things are disabled until the issue with encoder past key/values is
        # solved. First, seq2seq decoders should use the "decoder_present" prefix here.
        # Second, they should also expose flexible "encoder_present_{i}_key/value" outputs
        # (axis 2, for i in range(self.num_encoder_layers)).
        prefix = "present"
        for layer in range(self.num_layers):
            for part in ("key", "value"):
                # Axis 2 of every cached key/value output is flexible, from 1 to -1.
                flexible_shapes[f"{prefix}_{layer}_{part}"] = [
                    {"axis": 2, "min": 1, "max": -1},
                ]

    return flexible_shapes