in src/exporters/coreml/config.py [0:0]
def _output_descriptions(self) -> "OrderedDict[str, OutputDescription]":
    """Return the output descriptions for this config's task.

    Each entry maps the name of a model output (e.g. ``"logits"``,
    ``"pred_boxes"``) to an ``OutputDescription``. That object carries the
    output's name in the exported model, a human-readable description, and
    the flags ``do_softmax`` / ``do_upsample`` / ``do_argmax`` where a task
    sets them. Entries are listed in a fixed order per task.

    Returns:
        An ``OrderedDict`` of output name -> ``OutputDescription``.

    Raises:
        AssertionError: if ``self.task`` is not one of the tasks handled below.
    """
    # Feature extraction, and any config with ``seq2seq == "encoder"``
    # (presumably the encoder half of a split seq2seq export -- confirm),
    # expose only the final hidden states, regardless of ``self.task``.
    if self.task == "feature-extraction" or self.seq2seq == "encoder":
        return OrderedDict(
            [
                (
                    "last_hidden_state",
                    OutputDescription(
                        "last_hidden_state",
                        "Sequence of hidden-states at the output of the last layer of the model",
                    ),
                ),
            ]
        )

    # Sequence-level classification: softmaxed probabilities plus the
    # top-scoring class label.
    if self.task in [
        "image-classification",
        "multiple-choice",
        "next-sentence-prediction",
        "text-classification",
    ]:
        return OrderedDict(
            [
                (
                    "logits",
                    OutputDescription(
                        "probabilities",
                        "Probability of each category",
                        do_softmax=True,
                    ),
                ),
                (
                    "class_labels",
                    OutputDescription(
                        "classLabel",
                        "Category with the highest score",
                    ),
                ),
            ]
        )

    # These tasks expose raw logits (no softmax).
    if self.task in [
        "masked-im",
        "text-generation",
        "text2text-generation",
    ]:
        return OrderedDict(
            [
                (
                    "logits",
                    OutputDescription(
                        "logits",
                        "Classification scores (before softmax)",
                        do_softmax=False,
                    ),
                ),
            ]
        )

    # Tasks whose logits score every position against the vocabulary.
    if self.task in [
        "automatic-speech-recognition",
        "fill-mask",
        "speech-seq2seq",
    ]:
        return OrderedDict(
            [
                (
                    "logits",
                    OutputDescription(
                        "token_scores",
                        "Classification scores for each vocabulary token (after softmax)",
                        do_softmax=True,
                    ),
                ),
            ]
        )

    # Token classification scores each input token against the label set,
    # not the vocabulary, so it gets its own description. The output name
    # and softmax flag match the vocabulary-scoring tasks above.
    if self.task == "token-classification":
        return OrderedDict(
            [
                (
                    "logits",
                    OutputDescription(
                        "token_scores",
                        "Classification scores for each input token (after softmax)",
                        do_softmax=True,
                    ),
                ),
            ]
        )

    if self.task == "object-detection":
        return OrderedDict(
            [
                (
                    "logits",
                    OutputDescription(
                        "logits",
                        "Classification logits (including no-object) for all queries",
                        do_softmax=False,
                    ),
                ),
                (
                    "pred_boxes",
                    OutputDescription(
                        "pred_boxes",
                        "Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height)",
                    ),
                ),
            ]
        )

    # Extractive QA: separate softmaxed scores for span start and span end.
    if self.task == "question-answering":
        return OrderedDict(
            [
                (
                    "start_logits",
                    OutputDescription(
                        "start_scores",
                        "Span-start scores (after softmax)",
                        do_softmax=True,
                    ),
                ),
                (
                    "end_logits",
                    OutputDescription(
                        "end_scores",
                        "Span-end scores (after softmax)",
                        do_softmax=True,
                    ),
                ),
            ]
        )

    # Segmentation requests the upsample and argmax flags in addition to raw
    # (non-softmaxed) per-pixel scores.
    if self.task == "semantic-segmentation":
        return OrderedDict(
            [
                (
                    "logits",
                    OutputDescription(
                        "classLabels",
                        "Classification scores for each pixel",
                        do_softmax=False,
                        do_upsample=True,
                        do_argmax=True,
                    ),
                ),
            ]
        )

    raise AssertionError(f"Unsupported task '{self.task}' for modality '{self.modality}'")