optimum/graphcore/models/whisper/modeling_whisper.py

# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Optional, Tuple, Union

import poptorch
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperAttention,
    WhisperDecoder,
    WhisperEncoder,
    WhisperEncoderLayer,
    WhisperForConditionalGeneration,
    WhisperPositionalEmbedding,
)

from optimum.utils import logging

from ...generation import IPUAttentionMixin, IPUGenerationMixin, assert_poptorch_supports_cond, supports_kv_cache
from ...modeling_utils import (
    PipelineMixin,
    SerializedLinear,
    get_layer_ipu,
    recomputation_checkpoint,
    register,
    shift_tokens_right,
    split_encoder_decoder_ipu_config,
)


logger = logging.get_logger(__name__)

FLOAT16_LIMIT = 1e4


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), -FLOAT16_LIMIT)
    mask_cond = torch.arange(mask.size(-1))
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
""" bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), -FLOAT16_LIMIT) class IPUWhisperAttention(WhisperAttention, IPUAttentionMixin): def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, tgt_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) * self.scaling if key_value_states is not None: if self.cross_kv_cache_initialized: # cross attention with cross kv cache key_states, value_states = self.add_to_cross_kv_cache( key_value_states, lambda x: self._shape(self.k_proj(x), -1, bsz), lambda x: self._shape(self.v_proj(x), -1, bsz), ) else: # cross attention key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif self.kv_cache_initialized: # self attention with kv cache key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if tgt_len != 1: raise ValueError(f"KV cache expects tgt_len = 1, received {tgt_len}.") key_states, value_states = self.add_to_kv_cache(key_states, value_states) attention_mask = self.update_attention_mask(attention_mask) else: # self attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # We handle the KV cache via buffers, not via the eager approach of passing the cache around. # This is retained so upstream DecoderLayer can stay as is and that tests pass. past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) # Change: optionally serialize attention, mainly intended for the encoder with large sequence length. 
        if self.is_attention_serialized:
            if layer_head_mask is not None:
                raise ValueError("layer_head_mask is not supported yet with serialized attention.")
            if self.dropout and self.training:
                raise ValueError("dropout is not supported yet with serialized attention.")

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                    )
                attention_mask = attention_mask.view(bsz, tgt_len, src_len).repeat(self.num_heads, 1, 1)

            attn_output = self.serialized_attention(query_states, key_states, value_states, 1.0, attention_mask)
        else:
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

            if layer_head_mask is not None:
                if layer_head_mask.size() != (self.num_heads,):
                    raise ValueError(
                        f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                        f" {layer_head_mask.size()}"
                    )
                attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                    bsz, self.num_heads, tgt_len, src_len
                )
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            # Change: delete optional reshaping of attn_weights

            attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
            attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        # Change: modified check for output_attentions
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class _WhisperEncoderLayerClamp(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding
                elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
""" residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # NOTE: This differs from the original implementation # There is a type mismatch bug with this call to clamp so we remove it here. It is anyway not needed on IPU because FP16 values are clamped to maximum value by default. # TODO: when bug is fixed in future SDK remove this entire class; # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class IPUWhisperPositionalEmbedding(WhisperPositionalEmbedding): @classmethod def from_model(cls, model: WhisperPositionalEmbedding): clone = copy.deepcopy(model) clone.__class__ = cls clone.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False) return clone def to_model(self) -> WhisperPositionalEmbedding: del self._generation_step original = copy.deepcopy(self) original.__class__ = WhisperPositionalEmbedding return original def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): if input_ids.shape[-1] == 1: # KV cache enabled. del past_key_values_length return torch.index_select(self.weight, 0, self._generation_step) else: return super().forward(input_ids, past_key_values_length=past_key_values_length) class _WhisperDecoderWithCustomMakeCausalAndExpandMask(WhisperDecoder): """ Transformer decoder consisting of *config.decoder_layers* layers. 
    Each layer is a [`WhisperDecoderLayer`]

    Args:
        config: WhisperConfig
    """

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
            ).to(inputs_embeds.device)

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask


class IPUWhisperConditionalEncoder(WhisperEncoder):
    @classmethod
    def from_model(cls, model: WhisperEncoder, batch_size: int, num_beams: int):
        clone = model
        clone.__class__ = cls
        clone.register_buffer(
            "_encoder_last_hidden_state",
            torch.zeros((batch_size, model.config.max_source_positions, model.config.d_model), dtype=model.dtype),
            persistent=False,
        )
        clone.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False)
        clone._batch_size = batch_size
        clone._num_beams = num_beams
        return clone

    def to_model(self) -> WhisperEncoder:
        del self._encoder_last_hidden_state
        del self._generation_step
        del self._batch_size
        del self._num_beams

        original = self
        original.__class__ = WhisperEncoder
        return original

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if attention_mask is not None or head_mask is not None or output_attentions or output_hidden_states:
            raise ValueError(f"{self.__class__.__name__} only accepts `input_features`.")

        def run_encoder(input_features):
            encoder_output = WhisperEncoder.forward(self, input_features, return_dict=True)
            return encoder_output.last_hidden_state

        def skip_encoder(input_features):
            return self._encoder_last_hidden_state

        self._encoder_last_hidden_state.copy_(
            poptorch.cond(self._generation_step == 0, run_encoder, [input_features], skip_encoder, [input_features])[0]
        )
        last_hidden_state = self._encoder_last_hidden_state
        if self._num_beams > 1:
            # Before being passed to the decoder, we must expand the encoder outputs when beam search is enabled
            # as this would be done on host.
            last_hidden_state = last_hidden_state.repeat_interleave(
                self._num_beams, dim=0, output_size=self._batch_size * self._num_beams
            )
        return BaseModelOutput(last_hidden_state=last_hidden_state)


@supports_kv_cache
@register(WhisperForConditionalGeneration)
class PipelinedWhisperForConditionalGeneration(WhisperForConditionalGeneration, PipelineMixin, IPUGenerationMixin):
    def change_encoder_layer_class(self, restore: bool):
        """Changes the encoder layer class to avoid the dynamic 'if'

        Args:
            restore: whether to restore the encoder layers to their original version or not.
        """
        for layer in self.model.encoder.layers:
            layer.__class__ = WhisperEncoderLayer if restore else _WhisperEncoderLayerClamp

    def change_encoder_class(self, restore: bool, **kwargs):
        """Changes the encoder class to run the encoder under a `poptorch.cond` op.

        Args:
            restore: whether to restore the encoder to its original version or not.
""" batch_size = kwargs.get("batch_size", 1) num_beams = kwargs.get("num_beams", 1) encoder = self.model.get_encoder() if restore: if isinstance(encoder, IPUWhisperConditionalEncoder): self.model.encoder = encoder.to_model() else: if self.ipu_config.inference_ipus_per_replica > 1: raise ValueError( f"`{self.ipu_config.inference_ipus_per_replica=}` should be 1 when placing encoder and decoder on the same IPU." ) assert_poptorch_supports_cond( context="Whisper encoder is being conditionally run on the same IPU as the decoder since `use_cond_encoder=True`." ) self.model.encoder = IPUWhisperConditionalEncoder.from_model(encoder, batch_size, num_beams) def change_decoder_class(self, restore: bool): """Changes the decoder class to update the _prepare_decoder_attention_mask method. Args: restore: whether to restore the decoder to its original version or not. """ self.model.decoder.__class__ = WhisperDecoder if restore else _WhisperDecoderWithCustomMakeCausalAndExpandMask def change_decoder_positional_embedding(self, restore: bool): """Changes the decoder positional embedding to support an optional static KV cache. Args: restore: whether to restore the decoder positional embedding to their original version or not. """ position_embedding = self.model.decoder.embed_positions self.model.decoder.embed_positions = ( position_embedding.to_model() if restore else IPUWhisperPositionalEmbedding.from_model(position_embedding) ) def change_attention_class(self, restore=False, **kwargs): """Change the attention layers to support a KV cache. Args: restore: whether to restore the attention layers to their original version or not. """ batch_size = kwargs.get("batch_size", 1) num_beams = kwargs.get("num_beams", 1) use_cache = kwargs.get("use_cache", False) max_length = kwargs.get("max_length", 448) use_cross_cache = kwargs.get("use_cross_cache", False) encoder_max_length = kwargs.get("encoder_max_length", 1500) batch_serialization_factor = kwargs.get("batch_serialization_factor", 1) sequence_serialization_factor = kwargs.get("sequence_serialization_factor", 1) for encoder_layer in self.model.encoder.layers: if restore: encoder_layer.self_attn = encoder_layer.self_attn.to_model(WhisperAttention) continue encoder_layer.self_attn = IPUWhisperAttention.from_model( encoder_layer.self_attn, use_cache=False, batch_serialization_factor=batch_serialization_factor, sequence_serialization_factor=sequence_serialization_factor, ) for decoder_layer in self.model.decoder.layers: if restore: decoder_layer.self_attn = decoder_layer.self_attn.to_model(WhisperAttention) decoder_layer.encoder_attn = decoder_layer.encoder_attn.to_model(WhisperAttention) continue decoder_layer.self_attn = IPUWhisperAttention.from_model( decoder_layer.self_attn, use_cache=use_cache, use_cross_cache=False, batch_size=batch_size, max_length=max_length, num_beams=num_beams, dtype=decoder_layer.self_attn.k_proj.weight.dtype, ) decoder_layer.encoder_attn = IPUWhisperAttention.from_model( decoder_layer.encoder_attn, use_cache=False, use_cross_cache=use_cross_cache, batch_size=batch_size, encoder_max_length=encoder_max_length, num_beams=num_beams, dtype=decoder_layer.encoder_attn.k_proj.weight.dtype, ) def change_lm_head(self, restore: bool, use_cache: bool = None): # Maybe use _IndexedInputLinear self.change_lm_head_to_indexed_input_linear(restore or use_cache) # Maybe use SerializedLinear if restore: lm_head = self.get_output_embeddings() if isinstance(lm_head, SerializedLinear): self.set_output_embeddings(lm_head.to_model()) self.tie_weights() else: 
            projection_serialization_factor = max(
                self.ipu_config._projection_serialization_factor or 1,
                sum(self.ipu_config._serialized_projection_splits_per_ipu or [1]),
            )
            if projection_serialization_factor > 1:
                self.set_output_embeddings(
                    SerializedLinear.from_model(self.get_output_embeddings(), projection_serialization_factor)
                )
                self.tie_weights()

    def quantize_linear_layers(self, restore: bool, num_groups: int = 16):
        if not restore:
            from ...quantization.group_quantize import GroupQuantLinear

            logger.info("Group quantizing linear layers")
            for module in self.model.encoder.layers:
                module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
                module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
                module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
                module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
                module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
                module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)

            for module in self.model.decoder.layers:
                module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
                module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
                module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
                module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
                module.encoder_attn.q_proj = GroupQuantLinear.from_model(module.encoder_attn.q_proj, num_groups)
                module.encoder_attn.k_proj = GroupQuantLinear.from_model(module.encoder_attn.k_proj, num_groups)
                module.encoder_attn.v_proj = GroupQuantLinear.from_model(module.encoder_attn.v_proj, num_groups)
                module.encoder_attn.out_proj = GroupQuantLinear.from_model(module.encoder_attn.out_proj, num_groups)
                module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
                module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)

    def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=False, **kwargs):
        super().parallelize()

        if use_cache:
            kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)

        self._use_cond_encoder = kwargs.get("use_cond_encoder", False)
        self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
        if self._use_cond_encoder and self._use_encoder_output_buffer:
            raise ValueError(
                "`use_cond_encoder=True` is incompatible with `use_encoder_output_buffer=True`, only set one to True."
            )
        self._use_group_quantized_linears = kwargs.get("use_group_quantized_linears", False)

        self.change_encoder_layer_class(restore=False)
        self.change_decoder_class(restore=False)
        self.change_decoder_positional_embedding(restore=False)
        self.change_attention_class(
            restore=False,
            use_cache=use_cache and for_generation,
            use_cross_cache=use_cross_cache and for_generation,
            **kwargs,
        )
        self.change_lm_head(restore=False, use_cache=use_cache or not for_generation)
        self.change_encoder_class(restore=not self._use_cond_encoder, **kwargs)
        self.quantize_linear_layers(restore=not self._use_group_quantized_linears, num_groups=16)
        self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))

        logger.info("---------- Device Allocation -----------")
        logger.info("conv1, conv2, embed_positions --> IPU 0")
        self.model.encoder.conv1 = poptorch.BeginBlock(self.model.encoder.conv1, "Conv1", ipu_id=0)
        self.model.encoder.conv2 = poptorch.BeginBlock(self.model.encoder.conv2, "Conv2", ipu_id=0)
        self.model.encoder.embed_positions = poptorch.BeginBlock(
            self.model.encoder.embed_positions, "Embed Positions", ipu_id=0
        )

        num_encoder_layers = len(self.model.encoder.layers)
        num_decoder_layers = len(self.model.decoder.layers)

        if for_generation and not self._use_cond_encoder:
            # If running for text generation (and the encoder and decoder are run as separate Poplar executors)
            # we split the IPU config into two configs.
            ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
            self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
            encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
            decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
        else:
            number_of_layers = num_encoder_layers + num_decoder_layers
            layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
            encoder_layer_ipu = layer_ipu[:num_encoder_layers]
            decoder_layer_ipu = layer_ipu[num_encoder_layers:]

        for index, (layer, ipu) in enumerate(zip(self.model.encoder.layers, encoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        # we need to deal with the model.encoder.layer norm
        self.model.encoder.layer_norm = poptorch.BeginBlock(
            self.model.encoder.layer_norm, "Encoder Layer Norm", ipu_id=ipu
        )
        logger.info(f"Encoder LN --> IPU {ipu}")

        decoder_embedding_ipu = decoder_layer_ipu[0]
        if (serialized_projection_splits_per_ipu := self.ipu_config._serialized_projection_splits_per_ipu) is not None:
            serialized_projection_ipus = [i for i, x in enumerate(serialized_projection_splits_per_ipu) if x]
            if len(serialized_projection_ipus) > 1:
                # This is because we are using SerializedLinear. All splits of a SerializedLinear layer must be on the
                # same IPU. We are using SerializedLinear instead of SplitLinear because we must tie the weights, which
                # cannot be done when using SplitLinear.
                raise ValueError(
                    "`serialized_projection_splits_per_ipu` must only have 1 non-zero element for Whisper."
                )
            decoder_embedding_ipu = serialized_projection_ipus[0]

        self.model.decoder.embed_tokens = poptorch.BeginBlock(
            self.model.decoder.embed_tokens, "Decoder Embedding", ipu_id=decoder_embedding_ipu
        )
        logger.info(f"Decoder Embedding --> IPU {decoder_embedding_ipu}")

        prev_ipu = decoder_layer_ipu[0]
        for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            if ipu != prev_ipu:
                self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
                prev_ipu = ipu
            logger.info(f"Decoder {index:<2} --> IPU {ipu}")

        self.model.decoder.layer_norm = poptorch.BeginBlock(
            self.model.decoder.layer_norm, "Decoder Layer Norm", ipu_id=ipu
        )

        logger.info(f"Head --> IPU {decoder_embedding_ipu}")
        logger.info("---------------------------------------")
        self.proj_out = poptorch.BeginBlock(self.proj_out, "Output Projection", ipu_id=decoder_embedding_ipu)

        return self

    def deparallelize(self):
        super().deparallelize()

        self.change_encoder_layer_class(restore=True)
        self.change_decoder_class(restore=True)
        self.change_decoder_positional_embedding(restore=True)
        self.change_attention_class(restore=True)
        self.change_lm_head(restore=True)
        self.change_encoder_class(restore=True)
        self.set_on_device_generation_steps(0)

        return self

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        use_cache=None,
        encoder_outputs=None,
        attention_mask=None,
        **kwargs,
    ):
        # We don't use `past_key_values` for KV caching, and rely on `use_cache` instead.
        beam_idx = None
        if use_cache:
            decoder_input_ids = decoder_input_ids[:, -1:]
            beam_idx = kwargs.get("beam_idx", torch.arange(decoder_input_ids.shape[0], dtype=torch.long))

        ret = {
            "past_key_values": None,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": None,
            "beam_idx": beam_idx,
        }

        if self.cond_encoder_enabled:
            input_features = kwargs.get("input_features", None)
            if input_features is None:
                raise ValueError("Missing `input_features` with `use_cond_encoder=True`.")
            ret["input_features"] = input_features
        else:
            ret["encoder_outputs"] = encoder_outputs

        return ret

    # TODO: consider making such output subsetting a decorator
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        # Duplicate this portion of upstream logic so we can intercept the call to `shift_tokens_right`.
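        # Added worked example (assuming the local `shift_tokens_right` helper mirrors the standard transformers
        # behaviour): with pad_token_id = P and decoder_start_token_id = S,
        #   labels            = [[t1, t2, -100, -100]]
        #   decoder_input_ids = [[ S, t1,   t2,    P]]
        # i.e. tokens are shifted one position to the right, S is prepended, and any -100 label padding that remains
        # after the shift is replaced with P so the embedding lookup stays in range.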
        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        output = super().forward(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Minimize IO and only return loss when training.
        return Seq2SeqLMOutput(
            loss=output.loss,
            logits=None if self.training else output.logits,
            encoder_last_hidden_state=None if self.training else output.encoder_last_hidden_state,  # for tests to pass
            past_key_values=None if self.training else output.past_key_values,  # for tests to pass
        )
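
# ---------------------------------------------------------------------------------------------------------------------
# Illustrative usage sketch (comments only, nothing here runs on import). It shows how the pipelined class above is
# typically obtained and driven for generation. `IPUConfig`, `to_pipelined`, the checkpoint name and the generation
# settings are assumptions based on the wider optimum-graphcore API rather than definitions from this file, and
# `audio_array` stands in for a 16 kHz waveform.
#
#   from transformers import WhisperForConditionalGeneration, WhisperProcessor
#
#   from optimum.graphcore import IPUConfig
#   from optimum.graphcore.modeling_utils import to_pipelined
#
#   processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
#   model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
#
#   ipu_config = IPUConfig(inference_ipus_per_replica=2)
#   pipelined = to_pipelined(model, ipu_config).half()
#   pipelined = pipelined.parallelize(for_generation=True, use_cache=True, batch_size=1, num_beams=1, max_length=448)
#
#   input_features = processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features.half()
#   generated_ids = pipelined.generate(input_features, use_cache=True, max_length=448)
#   transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)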