optimum/neuron/models/bert/model.py (244 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model on Neuron devices."""

import logging
from typing import Optional

import torch
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
    AutoModelForMultipleChoice,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)

from ...modeling_traced import NeuronTracedModel
from ...utils.doc import (
    _TOKENIZER_FOR_DOC,
    NEURON_FEATURE_EXTRACTION_EXAMPLE,
    NEURON_MASKED_LM_EXAMPLE,
    NEURON_MODEL_START_DOCSTRING,
    NEURON_MULTIPLE_CHOICE_EXAMPLE,
    NEURON_QUESTION_ANSWERING_EXAMPLE,
    NEURON_SEQUENCE_CLASSIFICATION_EXAMPLE,
    NEURON_TEXT_INPUTS_DOCSTRING,
    NEURON_TOKEN_CLASSIFICATION_EXAMPLE,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)


logger = logging.getLogger(__name__)


def _build_neuron_inputs(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: Optional[torch.Tensor] = None,
) -> dict:
    """Assemble the named inputs handed to `neuron_padding_manager`.

    Args:
        input_ids: Token ids from the tokenizer.
        attention_mask: Attention mask matching `input_ids`.
        token_type_ids: Optional segment ids; only added to the mapping when not None.

    Returns:
        A dict with keys `input_ids`, `attention_mask` and, when provided,
        `token_type_ids`, in that insertion order.
    """
    # Insertion order matters: callers pass the padded inputs positionally to
    # the traced model via `self.model(*inputs)`.
    neuron_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
    if token_type_ids is not None:
        neuron_inputs["token_type_ids"] = token_type_ids
    return neuron_inputs


@add_start_docstrings(
    """
    Bare Bert Model transformer outputting raw hidden-states without any specific head on top, used for the task "feature-extraction".
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronBertModel(NeuronTracedModel):
    auto_model_class = AutoModel

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + NEURON_FEATURE_EXTRACTION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronBertModel",
            checkpoint="optimum/bert-base-uncased-neuronx-bs1-sq128",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Runs the traced model and strips padding back to the caller's
        # (batch_size, sequence_length). Extra `**kwargs` are accepted but ignored.
        neuron_inputs = _build_neuron_inputs(input_ids, attention_mask, token_type_ids)

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)
            # last_hidden_state -> (batch_size, sequence_len, hidden_size)
            # Remove padding on batch_size(0), and sequence_length(1)
            last_hidden_state = self.remove_padding(
                [outputs[0]], dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )[0]
            # The traced graph may or may not expose a pooler output.
            if len(outputs) > 1:
                # pooler_output -> (batch_size, hidden_size); remove padding on batch_size(0)
                pooler_output = self.remove_padding([outputs[1]], dims=[0], indices=[input_ids.shape[0]])[0]
            else:
                pooler_output = None

        return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state, pooler_output=pooler_output)


@add_start_docstrings(
    """
    Masked language Bert Model with a `language modeling` head on top, for masked language modeling tasks on Neuron devices.
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronBertForMaskedLM(NeuronTracedModel):
    auto_model_class = AutoModelForMaskedLM

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + NEURON_MASKED_LM_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronBertForMaskedLM",
            checkpoint="optimum/legal-bert-base-uncased-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Returns unpadded MLM logits. Extra `**kwargs` are accepted but ignored.
        neuron_inputs = _build_neuron_inputs(input_ids, attention_mask, token_type_ids)

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)  # shape: (batch_size, sequence_len, vocab_size)
            outputs = self.remove_padding(
                outputs, dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )  # Remove padding on batch_size(0), and sequence_length(1)

        logits = outputs[0]

        return MaskedLMOutput(logits=logits)


@add_start_docstrings(
    """
    Bert with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronBertForQuestionAnswering(NeuronTracedModel):
    auto_model_class = AutoModelForQuestionAnswering

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + NEURON_QUESTION_ANSWERING_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronBertForQuestionAnswering",
            checkpoint="optimum/bert-base-cased-squad2-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Returns unpadded start/end span logits. Extra `**kwargs` are accepted but ignored.
        neuron_inputs = _build_neuron_inputs(input_ids, attention_mask, token_type_ids)

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)  # shape: [batch_size, sequence_length]
            outputs = self.remove_padding(
                outputs, dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )  # Remove padding on batch_size(0), and sequence_length(1)

        start_logits = outputs[0]
        end_logits = outputs[1]

        return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)


@add_start_docstrings(
    """
    Neuron Model with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronBertForSequenceClassification(NeuronTracedModel):
    auto_model_class = AutoModelForSequenceClassification

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + NEURON_SEQUENCE_CLASSIFICATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronBertForSequenceClassification",
            checkpoint="optimum/bert-base-multilingual-uncased-sentiment-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Returns per-sequence logits with batch padding removed.
        # Extra `**kwargs` are accepted but ignored.
        neuron_inputs = _build_neuron_inputs(input_ids, attention_mask, token_type_ids)

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)  # shape: [batch_size, num_labels]
            # Only the batch dim is unpadded: the output has no sequence axis.
            outputs = self.remove_padding(
                outputs, dims=[0], indices=[input_ids.shape[0]]
            )  # Remove padding on batch_size(0)

        logits = outputs[0]

        return SequenceClassifierOutput(logits=logits)


@add_start_docstrings(
    """
    Neuron Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronBertForTokenClassification(NeuronTracedModel):
    auto_model_class = AutoModelForTokenClassification

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + NEURON_TOKEN_CLASSIFICATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronBertForTokenClassification",
            checkpoint="optimum/bert-base-NER-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Returns unpadded per-token logits. Extra `**kwargs` are accepted but ignored.
        neuron_inputs = _build_neuron_inputs(input_ids, attention_mask, token_type_ids)

        # run inference
        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)  # shape: [batch_size, sequence_length, num_labels]
            outputs = self.remove_padding(
                outputs, dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )  # Remove padding on batch_size(0), and sequence_length(1)

        logits = outputs[0]

        return TokenClassifierOutput(logits=logits)


@add_start_docstrings(
    """
    Neuron Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronBertForMultipleChoice(NeuronTracedModel):
    auto_model_class = AutoModelForMultipleChoice

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
        + NEURON_MULTIPLE_CHOICE_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronBertForMultipleChoice",
            checkpoint="optimum/bert-base-cased-swag-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Returns unpadded per-choice logits. Extra `**kwargs` are accepted but ignored.
        neuron_inputs = _build_neuron_inputs(input_ids, attention_mask, token_type_ids)

        # run inference
        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)  # shape: [batch_size, num_choices]
            # input_ids is documented above as (batch_size, num_choices, sequence_length),
            # so input_ids.shape[1] is num_choices, which sits on the output's last dim.
            outputs = self.remove_padding(
                outputs, dims=[0, -1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )  # Remove padding on batch_size(0), and num_choices(-1)

        logits = outputs[0]

        return MultipleChoiceModelOutput(logits=logits)