def forward()

in src/exporters/coreml/convert.py
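
This forward() belongs to the Wrapper module that the exporter traces (with torch.jit.trace) before handing the result to coremltools. Because Core ML has no notion of nested tuples or keyword arguments, the method accepts every model input as a flat positional tensor, reassembles them into the keyword arguments the wrapped Transformers model expects, runs the model, and flattens the outputs again according to the task.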


        def forward(self, *all_inputs):
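            # `remaining` counts how many positional inputs have not yet been
            # assigned to a keyword argument; past_key_values are peeled off the
            # end first, and the branches below compare against what is left.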
            remaining = len(all_inputs)
            inputs = all_inputs[0]

            # Core ML's image preprocessing does not allow a different scaling
            # factor for each color channel, so do this manually.
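            # E.g. an image_std of [0.229, 0.224, 0.225] becomes a (1, 3, 1, 1)
            # tensor that broadcasts over the (batch, channels, height, width) input.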
            if hasattr(self.preprocessor, "image_std") and not is_image_std_same(self.preprocessor):
                image_std = torch.tensor(self.preprocessor.image_std).reshape(1, -1, 1, 1)
                inputs = inputs / image_std

            model_kwargs = {
                "return_dict": False,

                # CoreMLConfig's values_override is supposed to do this, but not all
                # models look at self.config.use_cache (e.g. ElectraForCausalLM).
                # Setting use_cache here doesn't work with all models either, so it
                # stays disabled:
                #"use_cache": self.config.use_past or self.config.seq2seq,
            }

            # Convert the past_key_values_x_key and _value inputs back into tuples,
            # as that is what the original model expects.
            # Assumes past_key_values are always the last inputs to the Wrapper.
            # An encoder-decoder model first gets all the decoder past_key_values
            # tensors, followed by the encoder ones, but they get combined into the
            # same 4-tuples.
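            # E.g. with num_layers == 2 and no seq2seq, the tail of all_inputs is
            #   (..., past_key_values_0_key, past_key_values_0_value,
            #         past_key_values_1_key, past_key_values_1_value)
            # which becomes ((key_0, value_0), (key_1, value_1)).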
            if self.config.use_past:
                # TODO: Temporarily disabled until we can solve the issue with encoder past key/values
                if False and self.config.seq2seq == "decoder":
                    num_decoder_layers = self.config.num_layers
                    num_encoder_layers = self.config.num_encoder_layers
                    remaining -= (num_decoder_layers + num_encoder_layers) * 2
                    past_key_values = []
                    for i in range(min(num_decoder_layers, num_encoder_layers)):
                        past_key_values.append((
                            all_inputs[remaining + i*2],
                            all_inputs[remaining + i*2 + 1],
                            all_inputs[remaining + num_decoder_layers*2 + i*2],
                            all_inputs[remaining + num_decoder_layers*2 + i*2 + 1],
                        ))
                    model_kwargs["past_key_values"] = past_key_values
                else:
                    remaining -= self.config.num_layers * 2
                    past_key_values = []
                    for i in range(self.config.num_layers):
                        past_key_values.append((
                            all_inputs[remaining + i*2],
                            all_inputs[remaining + i*2 + 1],
                        ))
                    model_kwargs["past_key_values"] = past_key_values

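            # Assign the leading positional inputs to keyword arguments. For the
            # decoder of an encoder-decoder model the fixed layout is:
            #   (decoder_input_ids, decoder_attention_mask, encoder_last_hidden_state,
            #    [encoder attention_mask], past_key_values...)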
            if self.config.seq2seq == "decoder":
                model_kwargs["decoder_input_ids"] = all_inputs[0]
                model_kwargs["decoder_attention_mask"] = all_inputs[1]
                model_kwargs["encoder_outputs"] = (all_inputs[2],)
                if remaining >= 4:
                    model_kwargs["attention_mask"] = all_inputs[3]
            elif self.config.modality == "text":
                if remaining >= 2:
                    model_kwargs["attention_mask"] = all_inputs[1]
                if remaining >= 4:
                    # Special case for T5
                    model_kwargs["decoder_input_ids"] = all_inputs[2]
                    model_kwargs["decoder_attention_mask"] = all_inputs[3]
                elif remaining == 3:
                    model_kwargs["token_type_ids"] = all_inputs[2]
            elif self.config.modality == "vision":
                if self.config.task == "masked-im":
                    model_kwargs["bool_masked_pos"] = all_inputs[1]

            # Run the model with the provided inputs.
            if self.config.seq2seq == "encoder":
                outputs = self.model.get_encoder()(inputs, **model_kwargs)
            elif self.config.seq2seq == "decoder":
                outputs = self.model(**model_kwargs)
            else:
                outputs = self.model(inputs, **model_kwargs)

            # Unpack the output `past_key_values` into a single tuple.
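            # The flat ordering mirrors the inputs: key_0, value_0, key_1, value_1, ...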
            presents = ()
            if self.config.use_past:
                if len(outputs) < 2:
                    raise ValueError(f"expected at least two output tensors, got {len(outputs)}")

                past_key_values_index = -2 if self.config.seq2seq == "decoder" else -1
                past_key_values = outputs[past_key_values_index]

                # TODO: Temporarily disabled until we can solve the issue with encoder past key/values
                if False and self.config.seq2seq == "decoder":
                    decoder_presents = ()
                    encoder_presents = ()
                    for i in range(len(past_key_values)):
                        for j in range(2):
                            decoder_presents = decoder_presents + (past_key_values[i][j],)
                            encoder_presents = encoder_presents + (past_key_values[i][j + 2],)

                    presents = decoder_presents + encoder_presents
                else:
                    for i in range(len(past_key_values)):
                        for j in range(2):
                            presents = presents + (past_key_values[i][j],)

            output_descs = self.config.outputs

            if self.config.task == "image-classification":
                output_desc = output_descs["logits"]
                if output_desc.do_softmax:
                    return torch.nn.functional.softmax(outputs[0], dim=1)
                else:
                    return outputs[0]  # logits

            if self.config.task == "masked-im":
                # Some models also return a loss even if no labels are provided
                # (e.g. ViT), so skip that output if it's present.
                return outputs[1] if len(outputs) >= 2 else outputs[0]  # logits

            if self.config.seq2seq != "encoder" and self.config.task in [
                "text-generation",
                "automatic-speech-recognition",
                "fill-mask",
                "multiple-choice",
                "next-sentence-prediction",
                "text2text-generation",
                "text-classification",
                "speech-seq2seq",
                "token-classification",
            ]:
                output_desc = output_descs["logits"]
                if output_desc.do_softmax:
                    prediction = torch.nn.functional.softmax(outputs[0], dim=-1)
                else:
                    prediction = outputs[0]  # logits

                return (prediction,) + presents

            if self.config.task == "object-detection":
                return outputs[0], outputs[1]  # logits, pred_boxes

            if self.config.task == "question-answering":
                output_desc = output_descs["start_logits"]
                if output_desc.do_softmax:
                    start_scores = torch.nn.functional.softmax(outputs[0], dim=-1)
                    end_scores = torch.nn.functional.softmax(outputs[1], dim=-1)
                    return start_scores, end_scores
                else:
                    return outputs[0], outputs[1]  # start_logits, end_logits

            if self.config.task == "semantic-segmentation":
                x = outputs[0]  # logits
                output_desc = output_descs["logits"]
                if output_desc.do_upsample:
                    x = torch.nn.functional.interpolate(
                        x, size=inputs.shape[-2:], mode="bilinear", align_corners=False
                    )
                if output_desc.do_softmax:
                    x = torch.nn.functional.softmax(x, dim=1)
                if output_desc.do_argmax:
                    x = x.argmax(1)
                return x

            if self.config.seq2seq == "encoder" and self.config.task in ["text2text-generation", "speech-seq2seq"]:
                return outputs[0]  # last_hidden_state

            if self.config.task == "feature-extraction":
                if self.config.use_past:
                    return (outputs[0],) + presents
                elif len(output_descs) > 1 and len(outputs) > 1:
                    return outputs[0], outputs[1]  # last_hidden_state, pooler_output
                else:
                    return outputs[0]  # last_hidden_state

            raise AssertionError(f"Cannot compute outputs for unknown task '{self.config.task}'")
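
For context, here is a minimal sketch of how a module with this forward() is typically traced and converted. The Wrapper constructor argument order, the variable names, and the example shapes are assumptions for illustration; convert.py derives the real example inputs from the config's input descriptions.

    import numpy as np
    import torch
    import coremltools as ct

    # Assumed constructor order: (preprocessor, model, config).
    wrapper = Wrapper(preprocessor, model, coreml_config).eval()

    # Positional layout for a plain text model without past_key_values:
    # (input_ids, attention_mask).
    input_ids = torch.randint(0, 1000, (1, 64))
    attention_mask = torch.ones(1, 64, dtype=torch.int64)

    traced = torch.jit.trace(wrapper, (input_ids, attention_mask))

    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="input_ids", shape=(1, 64), dtype=np.int32),
            ct.TensorType(name="attention_mask", shape=(1, 64), dtype=np.int32),
        ],
        convert_to="mlprogram",
    )

Tracing (rather than scripting) is what makes the flat positional signature necessary: the tracer records one concrete execution, so every optional input has to be resolved by counting the arguments actually passed, which is exactly what the `remaining` bookkeeping above does.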