optimum/onnx/configuration.py (66 lines of code) (raw):

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ONNX export configurations for the encoder and decoder halves of seq2seq models."""

from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Optional

from transformers.file_utils import TensorType
from transformers.utils import logging

if TYPE_CHECKING:
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast

logger = logging.get_logger(__name__)


class EncoderOnnxConfig(OnnxConfig):
    """ONNX config describing the standalone encoder export of a seq2seq model."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Encoder graph inputs with their dynamic axes (batch and sequence)."""
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Single encoder output: the final hidden states, dynamic on batch and sequence."""
        return OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})


class DecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
    """ONNX config describing the standalone decoder export (with optional past key-values)."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Decoder graph inputs; past key-value tensors are appended when `use_past` is set."""
        dynamic_axes = OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "past_decoder_sequence + sequence"}),
                ("encoder_hidden_states", {0: "batch", 1: "encoder_sequence"}),
                ("encoder_attention_mask", {0: "batch", 1: "encoder_sequence"}),
            ]
        )
        if self.use_past:
            self.fill_with_past_key_values_(dynamic_axes, direction="inputs")
        return dynamic_axes

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Dict[str, Any]:
        """Build dummy decoder inputs by reshaping the parent seq2seq dummy inputs.

        The parent produces encoder-style names (`decoder_input_ids`, `attention_mask`);
        this remaps them to the decoder-only graph's input names and substitutes a
        zero tensor for the encoder hidden states.
        """
        # Local import: torch is only needed when dummy tensors are materialized.
        import torch

        seq2seq_dummy = super().generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )
        batch, encoder_seq_length = seq2seq_dummy["input_ids"].shape
        # Placeholder encoder output: (batch, encoder_seq, hidden_size) of zeros.
        hidden_states_shape = (batch, encoder_seq_length, self._config.hidden_size)

        decoder_inputs = {
            "input_ids": seq2seq_dummy.pop("decoder_input_ids"),
            "encoder_hidden_states": torch.zeros(hidden_states_shape),
            "encoder_attention_mask": seq2seq_dummy.pop("attention_mask"),
        }
        if "past_key_values" in seq2seq_dummy:
            decoder_inputs["past_key_values"] = seq2seq_dummy.pop("past_key_values")
        return decoder_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Decoder outputs plus the flattened `present` key-value tensors.

        Skips `OnnxConfigWithPast.outputs` in the MRO on purpose: the plain
        `OnnxConfig.outputs` base mapping is taken, then past key-values are
        appended with this class's own flattening scheme.
        """
        base_outputs = super(OnnxConfigWithPast, self).outputs
        self.fill_with_past_key_values_(base_outputs, direction="outputs")
        return base_outputs

    def fill_with_past_key_values_(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """Append flattened past/present key-value entries to `inputs_or_outputs` in place.

        `direction` is "inputs" (names prefixed `past`, axis `past_decoder_sequence`)
        or "outputs" (names prefixed `present`, axis `past_decoder_sequence + sequence`).
        """
        _, decoder_layer_count = self.num_layers
        if direction == "inputs":
            prefix = "past"
            sequence_axis = "past_decoder_sequence"
        else:
            prefix = "present"
            sequence_axis = "past_decoder_sequence + sequence"
        # 4 tensors per decoder layer — presumably self- and cross-attention
        # key/value pairs; TODO(review): confirm against the exporting model.
        for idx in range(decoder_layer_count * 4):
            inputs_or_outputs[f"{prefix}_key_values_{idx}"] = {0: "batch", 2: sequence_axis}