optimum/onnx/modeling_seq2seq.py

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.file_utils import add_start_docstrings_to_model_forward


DECODER_WITH_LM_HEAD_INPUTS_DOCSTRING = r"""
Arguments:
    input_ids (`torch.LongTensor`):
        Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
    encoder_hidden_states (`torch.FloatTensor`):
        The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
    attention_mask (`torch.LongTensor`, *optional*):
        Mask to avoid performing attention on padding token indices of `input_ids`.
    encoder_attention_mask (`torch.LongTensor`, *optional*):
        Mask to avoid performing cross-attention on padding token indices of encoder `input_ids`.
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
        Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
        The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
        `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
        `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""


# Currently inherits from PreTrainedModel for export constraint coming from transformers.onnx.export
class _DecoderWithLMhead(PreTrainedModel):
    """
    Decoder model with a language modeling head on top.

    Arguments:
        model (`transformers.PreTrainedModel`):
            The model from which to extract the decoder and the language modeling head.
    """

    def __init__(self, model: PreTrainedModel):
        super().__init__(model.config)
        self.config = model.config
        self.decoder = model.get_decoder()
        self.lm_head = model.get_output_embeddings()
        self.final_logits_bias = getattr(model, "final_logits_bias", None)

    @add_start_docstrings_to_model_forward(DECODER_WITH_LM_HEAD_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            return_dict=True,
            use_cache=True,
        )

        last_hidden_state = decoder_outputs.last_hidden_state

        if self.config.model_type == "t5" and self.config.tie_word_embeddings:
            # T5 needs its output to be rescaled before projecting on vocab
            last_hidden_state = last_hidden_state * (self.config.d_model**-0.5)

        lm_logits = self.lm_head(last_hidden_state)

        # Add the final bias if present in the model
        if self.final_logits_bias is not None:
            lm_logits += self.final_logits_bias

        if labels is None:
            return lm_logits, decoder_outputs.past_key_values
        else:
            # Calculate loss
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            return loss, lm_logits, decoder_outputs.past_key_values