in src/exporters/coreml/convert.py
def forward(self, *all_inputs):
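"""Run the wrapped Hugging Face model on the flattened inputs Core ML provides.

The positional `all_inputs` are mapped onto the model's keyword arguments
according to the export config (modality, task, seq2seq role, use_past),
the model is invoked, and its outputs are flattened back into plain tensors
so they can be traced for Core ML conversion.
"""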
remaining = len(all_inputs)
inputs = all_inputs[0]
# Core ML's image preprocessing does not allow a different scaling
# factor for each color channel, so do this manually.
if hasattr(self.preprocessor, "image_std") and not is_image_std_same(self.preprocessor):
image_std = torch.tensor(self.preprocessor.image_std).reshape(1, -1, 1, 1)
inputs = inputs / image_std
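# Keyword arguments shared by every model invocation below; task- and
# modality-specific inputs are added to this dict as they are identified.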
model_kwargs = {
"return_dict": False,
# CoreMLConfig's values_override is supposed to set use_cache, but not all
# models respect self.config.use_cache (e.g. ElectraForCausalLM). Passing
# use_cache here doesn't work for every model either, so it stays disabled:
#"use_cache": self.config.use_past or self.config.seq2seq,
}
# Convert the past_key_values_x_key and _value inputs back into tuples,
# as that is what the original model expects.
# Assumes past_key_values are always the last inputs to the Wrapper.
# An encoder-decoder model first gets all the decoder past_key_values
# tensors, followed by the encoder ones, but they get combined into the
# same 4-tuples.
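# For illustration (hypothetical 2-layer decoder-only text model with
# use_past), the flattened inputs would arrive as:
#   (input_ids, attention_mask,
#    past_key_values_0_key, past_key_values_0_value,
#    past_key_values_1_key, past_key_values_1_value)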
if self.config.use_past:
# TODO: Temporarily disabled until we can solve the issue with encoder past key/values
if False and self.config.seq2seq == "decoder":
num_decoder_layers = self.config.num_layers
num_encoder_layers = self.config.num_encoder_layers
remaining -= (num_decoder_layers + num_encoder_layers) * 2
past_key_values = []
for i in range(min(num_decoder_layers, num_encoder_layers)):
past_key_values.append((
all_inputs[remaining + i*2],
all_inputs[remaining + i*2 + 1],
all_inputs[remaining + num_decoder_layers*2 + i*2],
all_inputs[remaining + num_decoder_layers*2 + i*2 + 1],
))
model_kwargs["past_key_values"] = past_key_values
else:
remaining -= self.config.num_layers * 2
past_key_values = []
for i in range(self.config.num_layers):
past_key_values.append((
all_inputs[remaining + i*2],
all_inputs[remaining + i*2 + 1],
))
model_kwargs["past_key_values"] = past_key_values
if self.config.seq2seq == "decoder":
model_kwargs["decoder_input_ids"] = all_inputs[0]
model_kwargs["decoder_attention_mask"] = all_inputs[1]
model_kwargs["encoder_outputs"] = (all_inputs[2],)
if remaining >= 4:
model_kwargs["attention_mask"] = all_inputs[3]
elif self.config.modality == "text":
if remaining >= 2:
model_kwargs["attention_mask"] = all_inputs[1]
if remaining >= 4:
# Special case for T5
model_kwargs["decoder_input_ids"] = all_inputs[2]
model_kwargs["decoder_attention_mask"] = all_inputs[3]
elif remaining == 3:
model_kwargs["token_type_ids"] = all_inputs[2]
elif self.config.modality == "vision":
if self.config.task == "masked-im":
model_kwargs["bool_masked_pos"] = all_inputs[1]
# Run the model with the provided inputs.
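# Seq2seq models are exported as two separate Core ML models: the encoder
# wrapper only runs self.model.get_encoder(), while the decoder wrapper passes
# everything (including the precomputed encoder hidden state) as keyword
# arguments.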
if self.config.seq2seq == "encoder":
outputs = self.model.get_encoder()(inputs, **model_kwargs)
elif self.config.seq2seq == "decoder":
outputs = self.model(**model_kwargs)
else:
outputs = self.model(inputs, **model_kwargs)
# Unpack the model's new `past_key_values` into a single flat tuple of
# tensors, mirroring the flattened past_key_values inputs.
presents = ()
if self.config.use_past:
if len(outputs) < 2:
raise ValueError(f"expected at least two output tensors, got {len(outputs)}")
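# With return_dict=False a seq2seq decoder typically returns
# (logits, past_key_values, encoder_last_hidden_state), so the cache sits at
# index -2; for decoder-only models it is the last output.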
past_key_values_index = -2 if self.config.seq2seq == "decoder" else -1
past_key_values = outputs[past_key_values_index]
# TODO: Temporarily disabled until we can solve the issue with encoder past key/values
if False and self.config.seq2seq == "decoder":
decoder_presents = ()
encoder_presents = ()
for i in range(len(past_key_values)):
for j in range(2):
decoder_presents = decoder_presents + (past_key_values[i][j],)
encoder_presents = encoder_presents + (past_key_values[i][j + 2],)
presents = decoder_presents + encoder_presents
else:
for i in range(len(past_key_values)):
for j in range(2):
presents = presents + (past_key_values[i][j],)
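# `self.config.outputs` describes each exported output and carries the
# post-processing flags (do_softmax, do_upsample, do_argmax) applied below.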
output_descs = self.config.outputs
if self.config.task == "image-classification":
output_desc = output_descs["logits"]
if output_desc.do_softmax:
return torch.nn.functional.softmax(outputs[0], dim=1)
else:
return outputs[0] # logits
if self.config.task == "masked-im":
# Some models (e.g. ViT) also return a loss even when no labels are provided,
# so skip that output if it's present.
return outputs[1] if len(outputs) >= 2 else outputs[0] # logits
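# Text-style tasks all return the (optionally softmaxed) logits, followed by
# any flattened past key/values computed above.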
if self.config.seq2seq != "encoder" and self.config.task in [
"text-generation",
"automatic-speech-recognition",
"fill-mask",
"multiple-choice",
"next-sentence-prediction",
"text2text-generation",
"text-classification",
"speech-seq2seq",
"token-classification",
]:
output_desc = output_descs["logits"]
if output_desc.do_softmax:
prediction = torch.nn.functional.softmax(outputs[0], dim=-1)
else:
prediction = outputs[0] # logits
return (prediction,) + presents
if self.config.task == "object-detection":
return outputs[0], outputs[1] # logits, pred_boxes
if self.config.task == "question-answering":
output_desc = output_descs["start_logits"]
if output_desc.do_softmax:
start_scores = torch.nn.functional.softmax(outputs[0], dim=-1)
end_scores = torch.nn.functional.softmax(outputs[1], dim=-1)
return start_scores, end_scores
else:
return outputs[0], outputs[1] # start_logits, end_logits
if self.config.task == "semantic-segmentation":
x = outputs[0] # logits
output_desc = output_descs["logits"]
if output_desc.do_upsample:
x = torch.nn.functional.interpolate(x, size=inputs.shape[-2:], mode="bilinear", align_corners=False)
if output_desc.do_softmax:
x = torch.nn.functional.softmax(x, dim=1)
if output_desc.do_argmax:
x = x.argmax(1)
return x
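# The encoder half of a seq2seq export only exposes the hidden states that
# the decoder model consumes as `encoder_outputs`.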
if self.config.seq2seq == "encoder" and self.config.task in ["text2text-generation", "speech-seq2seq"]:
return outputs[0] # last_hidden_state
if self.config.task == "feature-extraction":
if self.config.use_past:
return (outputs[0],) + presents
elif len(output_descs) > 1 and len(outputs) > 1:
return outputs[0], outputs[1] # last_hidden_state, pooler_output
else:
return outputs[0] # last_hidden_state
raise AssertionError(f"Cannot compute outputs for unknown task '{self.config.task}'")
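# Rough usage sketch (illustrative only; the actual conversion entry points
# live elsewhere in this module): the wrapper is expected to be traced with
# torch.jit.trace(wrapper, example_inputs) and the traced graph handed to
# coremltools for conversion.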