optimum/exporters/neuron/config.py (51 lines of code) (raw):

# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Modality-level Neuron export configurations.

Each class below pairs a set of dummy input generators with the input-shape
arguments it exposes; model-specific configurations subclass one of them.
"""

from typing import List

from ...utils import (
    DummyAudioInputGenerator,
    DummyBboxInputGenerator,
    DummyInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
    DummySeq2SeqPastKeyValuesGenerator,
    DummyTextInputGenerator,
    DummyVisionInputGenerator,
    logging,
)
from .base import NeuronDefaultConfig


logger = logging.get_logger(__name__)


class TextEncoderNeuronConfig(NeuronDefaultConfig):
    """Base configuration for encoder-only text models."""

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)
    # The nested pair presumably makes "num_choices" apply only to the
    # "multiple-choice" task. NOTE(review): the interpretation lives in
    # NeuronDefaultConfig, which is not visible here; confirm.
    INPUT_ARGS = ("batch_size", "sequence_length", ("multiple-choice", "num_choices"))


class VisionNeuronConfig(NeuronDefaultConfig):
    """Base configuration for vision models."""

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)
    INPUT_ARGS = ("batch_size", "num_channels", "width", "height")


class TextAndVisionNeuronConfig(NeuronDefaultConfig):
    """Base configuration for multi-modal models combining text and vision inputs."""

    # Text, pixel and bounding-box dummy inputs are all produced.
    # INPUT_ARGS is not overridden here, so whatever NeuronDefaultConfig
    # defines is inherited.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator, DummyBboxInputGenerator)


class AudioNeuronConfig(NeuronDefaultConfig):
    """Base configuration for audio models."""

    # A text generator is registered next to the audio one, so text-style
    # dummy inputs can be produced as well.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator, DummyTextInputGenerator)
    INPUT_ARGS = ("batch_size", "audio_sequence_length")


class TextSeq2SeqNeuronConfig(NeuronDefaultConfig):
    """Base configuration for encoder-decoder text models."""

    # Order matters: _create_dummy_input_generator_classes indexes this tuple
    # positionally (0 = encoder text, 1 = decoder text, 2 = past key/values).
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyTextInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
        DummySeq2SeqPastKeyValuesGenerator,
    )

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        """
        Build the dummy input generators used to export a seq2seq model.

        Despite its name, the method returns generator *instances*: one for
        each of the first three entries of `DUMMY_INPUT_GENERATOR_CLASSES`,
        in the order encoder text, decoder text, past key/values.

        Args:
            **kwargs: Forwarded unchanged to every generator constructor.

        Returns:
            A list `[encoder_text, decoder_text, past_key_values]`.
        """
        classes = self.DUMMY_INPUT_GENERATOR_CLASSES

        encoder_text = classes[0](self.task, self._normalized_config, **kwargs)
        decoder_text = classes[1](self.task, self._normalized_config, **kwargs)

        # The past key/values generator must agree with the encoder on its
        # sequence length, so it is taken from the encoder text generator.
        # NOTE(review): if `kwargs` already carries `encoder_sequence_length`,
        # this call raises TypeError (the keyword would be supplied twice).
        past_key_values = classes[2](
            self.task,
            self._normalized_config,
            encoder_sequence_length=encoder_text.sequence_length,
            **kwargs,
        )

        return [encoder_text, decoder_text, past_key_values]