optimum/graphcore/models/bart/modeling_bart.py

# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import random
from typing import List, Optional, Tuple, Union

import poptorch
import torch
import torch.nn as nn
from transformers import BartForConditionalGeneration, BartForSequenceClassification, BartModel
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from transformers.models.bart.modeling_bart import (
    BartAttention,
    BartDecoder,
    BartEncoder,
    BartEncoderLayer,
    BartLearnedPositionalEmbedding,
)

from optimum.utils import logging

from ...generation import IPUAttentionMixin, IPUGenerationMixin, supports_kv_cache
from ...modeling_utils import (
    PipelineMixin,
    SerializedLinear,
    SharedEmbedding,
    get_layer_ipu,
    recomputation_checkpoint,
    register,
    shift_tokens_right,
    split_encoder_decoder_ipu_config,
)


logger = logging.get_logger(__name__)

FLOAT16_LIMIT = 1e4


def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
    """Makes causal mask used for bi-directional self-attention.

    This differs from the original implementation by:
        - Making the mask creation simpler in terms of operations used
        - Changing the value for tokens to mask to something compatible with fp16
        - Not expanding the final mask to [bsz, 1, tgt_len, src_len]
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), -FLOAT16_LIMIT, dtype=dtype)
    mask = torch.triu(mask, diagonal=1).to(dtype=dtype)
    return mask[None, None, :, :]


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, 1, src_seq_len]`.

    This differs from the original implementation by:
        - Changing the value for tokens to mask to something compatible with fp16
        - Not expanding the final mask to [bsz, 1, tgt_len, src_len]
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :]
    inverted_mask = 1.0 - expanded_mask
    # Using FLOAT16_LIMIT instead of -float("inf") to avoid NaNs on the IPUs.
    inverted_mask = -FLOAT16_LIMIT * inverted_mask
    return inverted_mask.to(dtype)


class IPUBartAttention(BartAttention, IPUAttentionMixin):
    """The same as BartAttention without the attention mask shape check.

    This is needed because the original BartAttention checks that the attention mask shape is
    [bs, 1, tgt_len, src_len], but the pipelined implementation does not expand the mask, it just inserts
    dimensions, the shape is then [bs, 1, 1, src_len], and broadcasting does the rest.
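
    Shape sketch (illustrative only; the sizes below are assumptions, not values used by the class):
    ```
    import torch

    bs, num_heads, tgt_len, src_len = 2, 4, 8, 8
    scores = torch.zeros(bs, num_heads, tgt_len, src_len)
    mask = torch.zeros(bs, 1, 1, src_len)
    mask[:, :, :, -2:] = -1e4  # mask out two padded positions (-1e4 == -FLOAT16_LIMIT)
    assert (scores + mask).shape == (bs, num_heads, tgt_len, src_len)  # broadcasting over heads and tgt_len
    ```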
""" def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling if key_value_states is not None: # cross attention key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif self.kv_cache_initialized: # self attention with kv cache key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if tgt_len != 1: raise ValueError(f"KV cache expects tgt_len = 1, received {tgt_len}.") key_states, value_states = self.add_to_kv_cache(key_states, value_states) attention_mask = self.update_attention_mask(attention_mask) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" ) if attention_mask is not None: attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. 
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


class _BartEncoderLayerNoClamp(BartEncoderLayer):
    """
    Same as BartEncoderLayer except that the dynamic `if` statement clamping fp16 tensor values is removed.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where
                padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # Change: removing this `if` because it can't be statically compiled.
        # if hidden_states.dtype == torch.float16 and (
        #     torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        # ):
        #     clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        #     hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class _BartEncoderWithCustomExpandMask(BartEncoder):
    """The same as BartEncoder but uses a custom version of _expand_mask.

    Check the _expand_mask docstring for more information.
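
    Illustrative sketch of the custom mask (the input values below are assumptions, not part of the original notes):
    ```
    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = attend, 0 = padding
    expanded = _expand_mask(attention_mask, torch.float16)
    # expanded.shape == (1, 1, 1, 4); kept positions hold 0, the padded position holds -1e4.
    ```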
""" def forward( self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=False, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input = input_ids input_ids = input_ids.view(-1, input_ids.shape[-1]) elif inputs_embeds is not None: input = inputs_embeds[:, :, -1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input) embed_pos = embed_pos.to(inputs_embeds.device) hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: if head_mask.size()[0] != (len(self.layers)): raise ValueError( f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." ) for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, output_attentions) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(encoder_layer), hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) class _BartDecoderWithCustomMakeCausalAndExpandMask(BartDecoder): """The same as BartDecoder but uses a custom version of _make_causal_mask and _expand_mask. Check the _expand_mask docstring for more information. 
""" def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length ).to(self.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( inputs_embeds.device ) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask def forward( self, input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input = input_ids input_shape = input.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] input = inputs_embeds[:, :, -1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale attention_mask = self._prepare_decoder_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) # embed positions positions = self.embed_positions(input, past_key_values_length) positions = positions.to(inputs_embeds.device) hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: pass # if attn_mask.size()[0] != (len(self.layers)): # raise ValueError( # "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for 
{head_mask.size()[0]}." # ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) if self.training and (dropout_probability < self.layerdrop): continue past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: if use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, use_cache) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) class IPUBartLearnedPositionalEmbedding(BartLearnedPositionalEmbedding): """ This module learns positional embeddings up to a fixed maximum size. """ @classmethod def from_model(cls, model: BartLearnedPositionalEmbedding): clone = copy.deepcopy(model) clone.__class__ = cls clone.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False) return clone def to_model(self) -> BartLearnedPositionalEmbedding: del self._generation_step original = copy.deepcopy(self) original.__class__ = BartLearnedPositionalEmbedding return original def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): if input_ids.shape[-1] == 1: # KV cache enabled. del past_key_values_length return torch.index_select(self.weight, 0, self._generation_step + self.offset) else: return super().forward(input_ids, past_key_values_length) class _BartModelWithSharedEmbedding(BartModel): @property def is_encoder_and_decoder_embeddings_computation_shared(self): return isinstance(self.shared, SharedEmbedding) def encoder_and_decoder_embeddings_computation(self, use_shared_embedding: bool): """Sets the BartModel shared embedding layer to SharedEmbedding that combines the computation under one layer. 

        Args:
            use_shared_embedding: whether to use SharedEmbedding or not.
        """
        if use_shared_embedding:
            if isinstance(self.shared, SharedEmbedding):
                logger.warning("encoder and decoder embeddings computation is already shared")
            else:
                self.shared = SharedEmbedding(self.shared)
        else:
            if isinstance(self.shared, nn.Embedding):
                logger.warning("encoder and decoder embeddings computation is not shared")
            else:
                self.shared = self.shared.shared

    def change_bart_encoder_and_decoder_classes(self, restore: bool):
        """Changes the encoder and decoder classes to update their forward pass so that they use our custom versions
        of _make_causal_mask and _expand_mask.

        Args:
            restore: whether to restore the encoder and decoder to their original version or not.
        """
        self.encoder.__class__ = BartEncoder if restore else _BartEncoderWithCustomExpandMask
        self.decoder.__class__ = BartDecoder if restore else _BartDecoderWithCustomMakeCausalAndExpandMask
        for layer in self.encoder.layers:
            layer.__class__ = BartEncoderLayer if restore else _BartEncoderLayerNoClamp

    def change_bart_attention_class(self, restore: bool, **kwargs):
        """Changes the attention layers to use either the original BartAttention forward or the IPUBartAttention
        forward, which skips the attention mask shape check and optionally adds a static KV cache.

        Args:
            restore: whether to restore the attention layers to their original version or not.
        """
        use_cache = kwargs.get("use_cache", False)
        batch_size = kwargs.get("batch_size", 1)
        max_length = kwargs.get("max_length", 128)
        num_beams = kwargs.get("num_beams", 1)

        for encoder_layer in self.encoder.layers:
            if restore:
                encoder_layer.self_attn = encoder_layer.self_attn.to_model(BartAttention)
                continue
            encoder_layer.self_attn = IPUBartAttention.from_model(
                encoder_layer.self_attn,
                use_cache=False,
            )

        for decoder_layer in self.decoder.layers:
            if restore:
                decoder_layer.self_attn = decoder_layer.self_attn.to_model(BartAttention)
                decoder_layer.encoder_attn = decoder_layer.encoder_attn.to_model(BartAttention)
                continue
            decoder_layer.self_attn = IPUBartAttention.from_model(
                decoder_layer.self_attn,
                use_cache=use_cache,
                batch_size=batch_size,
                max_length=max_length,
                num_beams=num_beams,
                dtype=decoder_layer.self_attn.k_proj.weight.dtype,
            )
            decoder_layer.encoder_attn = IPUBartAttention.from_model(
                decoder_layer.encoder_attn,
                use_cache=False,
            )

    def change_decoder_positional_embedding(self, restore: bool):
        """Changes the decoder positional embedding to support an optional static KV cache.

        Args:
            restore: whether to restore the decoder positional embedding to its original version or not.
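
        Illustrative note (not part of the original docstring): with the KV cache enabled the decoder is fed a
        single token per step, so the swapped-in `IPUBartLearnedPositionalEmbedding` looks the position up from
        its `_generation_step` buffer rather than from the input length:
        ```
        # after from_model(), for a [bsz, 1] decoder input the forward pass reduces to:
        # positions = torch.index_select(self.weight, 0, self._generation_step + self.offset)
        ```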
""" position_embedding = self.decoder.embed_positions self.decoder.embed_positions = ( position_embedding.to_model() if restore else IPUBartLearnedPositionalEmbedding.from_model(position_embedding) ) def quantize_linear_layers(self, restore: bool, num_groups: int = 16): if restore: return from ...quantization.group_quantize import GroupQuantLinear logger.info("Group quantizing linear layers") for module in self.encoder.layers: module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups) module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups) module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups) module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups) module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups) module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups) for module in self.decoder.layers: module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups) module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups) module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups) module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups) module.encoder_attn.q_proj = GroupQuantLinear.from_model(module.encoder_attn.q_proj, num_groups) module.encoder_attn.k_proj = GroupQuantLinear.from_model(module.encoder_attn.k_proj, num_groups) module.encoder_attn.v_proj = GroupQuantLinear.from_model(module.encoder_attn.v_proj, num_groups) module.encoder_attn.out_proj = GroupQuantLinear.from_model(module.encoder_attn.out_proj, num_groups) module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups) module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups) def forward( self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, decoder_inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): # different to other models, Bart automatically creates decoder_input_ids from # input_ids if no decoder_input_ids are provided if decoder_input_ids is None and decoder_inputs_embeds is None: if input_ids is None: raise ValueError( "If no `decoder_input_ids` or `decoder_inputs_embeds` are " "passed, `input_ids` cannot be `None`. Please pass either " "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            if self.is_encoder_and_decoder_embeddings_computation_shared:
                inputs_embeds, decoder_inputs_embeds = self.shared(
                    input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    encoder_embed_scale=self.encoder.embed_scale,
                    decoder_embed_scale=self.decoder.embed_scale,
                )
                if inputs_embeds is not None:
                    input_ids = None
                if decoder_inputs_embeds is not None:
                    decoder_input_ids = None

            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@supports_kv_cache
@register(BartForConditionalGeneration)
class PipelinedBartForConditionalGeneration(BartForConditionalGeneration, PipelineMixin, IPUGenerationMixin):
    def parallelize(self, for_generation=False, use_cache=False, **kwargs):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - (If enabled) Replaces the LM head with a SerializedLinear layer
        - Adds recomputation checkpoints

        Recommended usage:
        ```
        model = PipelinedBartForConditionalGeneration(config).parallelize().half()
        ```
        """
        super().parallelize()

        if use_cache:
            kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)

        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding --> IPU 0")

        if self.ipu_config.embedding_serialization_factor > 1:
            self.lm_head = SerializedLinear.from_model(self.lm_head, self.ipu_config.embedding_serialization_factor)
            self.tie_weights()

        self.model.__class__ = _BartModelWithSharedEmbedding
        self.model.encoder_and_decoder_embeddings_computation(use_shared_embedding=True)
        self.model.change_bart_encoder_and_decoder_classes(restore=False)
        self.model.change_bart_attention_class(restore=False, use_cache=use_cache and for_generation, **kwargs)
        self.model.change_decoder_positional_embedding(restore=False)
        self.change_lm_head_to_indexed_input_linear(restore=not (for_generation and not use_cache))
        self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
        self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))
        self.model.quantize_linear_layers(restore=not kwargs.get("use_group_quantized_linears", False), num_groups=16)

        self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0)
        self.model.encoder.embed_positions = poptorch.BeginBlock(
            self.model.encoder.embed_positions, "Embedding", ipu_id=0
        )
        self.model.encoder.layernorm_embedding = poptorch.BeginBlock(
            self.model.encoder.layernorm_embedding, "Embedding", ipu_id=0
        )

        num_encoder_layers = len(self.model.encoder.layers)
        num_decoder_layers = len(self.model.decoder.layers)

        if for_generation:
            # If running for text generation we split the IPU config into two configs
            # because we run the encoder and decoder as separate Poplar executors.
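            # Illustrative example (assumed configuration, not prescriptive): for bart-base (6 encoder and
            # 6 decoder layers) with layers_per_ipu=[3, 3, 3, 3], the encoder executor would get the first two
            # IPUs and the decoder executor the last two, each with its own layers_per_ipu slice.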
            ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
            self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
            encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
            decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
        else:
            number_of_layers = num_encoder_layers + num_decoder_layers
            layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
            encoder_layer_ipu = layer_ipu[:num_encoder_layers]
            decoder_layer_ipu = layer_ipu[num_encoder_layers:]

        for index, (layer, ipu) in enumerate(zip(self.model.encoder.layers, encoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        self.model.decoder.embed_positions = poptorch.BeginBlock(
            self.model.decoder.embed_positions, "Embedding", ipu_id=0
        )
        self.model.decoder.layernorm_embedding = poptorch.BeginBlock(
            self.model.decoder.layernorm_embedding, "Embedding", ipu_id=0
        )

        for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
            logger.info(f"Decoder {index:<2} --> IPU {ipu}")

        logger.info("LM Head Output --> IPU 0")
        self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0)
        logger.info("-----------------------------------------------------------")
        return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        fully compatible with `transformers.BartForConditionalGeneration`.
        """
        super().deparallelize()

        self.model.encoder_and_decoder_embeddings_computation(False)
        self.model.change_bart_encoder_and_decoder_classes(True)
        self.model.change_bart_attention_class(True)
        self.model.change_decoder_positional_embedding(restore=True)
        self.model.__class__ = BartModel

        self.change_lm_head_to_indexed_input_linear(restore=True)
        self.set_on_device_generation_steps(0)

        if isinstance(self.lm_head, SerializedLinear):
            self.lm_head = self.lm_head.to_model()
            self.tie_weights()

        return self

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        use_cache=None,
        encoder_outputs=None,
        attention_mask=None,
        decoder_attention_mask=None,
        **kwargs,
    ):
        # We don't use `past_key_values` for KV caching, and rely on `use_cache` instead.
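        # Note (assumption about the generation flow): with `use_cache=True` the static KV cache inside
        # IPUBartAttention already holds the previous keys/values, so only the latest token is passed to the
        # decoder, and `beam_idx` lets the cache reorder its entries after beam search reordering.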
        beam_idx = None
        if use_cache:
            decoder_input_ids = decoder_input_ids[:, -1:]
            beam_idx = kwargs.get("beam_idx", torch.arange(decoder_input_ids.shape[0], dtype=torch.long))

        return {
            "encoder_outputs": encoder_outputs,
            "past_key_values": None,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "beam_idx": beam_idx,
        }

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.training:
            # Only returning the loss to make the communication between the host and the device faster.
            if not return_dict:
                return outputs[0:1]
            else:
                return Seq2SeqLMOutput(loss=outputs.loss)
        else:
            return outputs


@register(BartForSequenceClassification)
class PipelinedBartForSequenceClassification(BartForSequenceClassification, PipelineMixin):
    def parallelize(self, **kwargs):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - Adds recomputation checkpoints

        Recommended usage:
        ```
        model = PipelinedBartForSequenceClassification(config).parallelize().half()
        ```
        """
        super().parallelize()

        self.model.__class__ = _BartModelWithSharedEmbedding
        self.model.encoder_and_decoder_embeddings_computation(use_shared_embedding=True)
        self.model.change_bart_encoder_and_decoder_classes(restore=False)
        self.model.change_bart_attention_class(restore=False)
        self.model.quantize_linear_layers(restore=not kwargs.get("use_group_quantized_linears", False), num_groups=16)

        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding --> IPU 0")
        self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0)
        self.model.encoder.embed_positions = poptorch.BeginBlock(
            self.model.encoder.embed_positions, "Embedding", ipu_id=0
        )
        self.model.encoder.layernorm_embedding = poptorch.BeginBlock(
            self.model.encoder.layernorm_embedding, "Embedding", ipu_id=0
        )

        number_of_layers = len(self.model.encoder.layers) + len(self.model.decoder.layers)
        layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)

        for index, layer in enumerate(self.model.encoder.layers):
            ipu = layer_ipu[index]
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        self.model.decoder.embed_positions = poptorch.BeginBlock(
            self.model.decoder.embed_positions, "Embedding", ipu_id=0
        )
        self.model.decoder.layernorm_embedding = poptorch.BeginBlock(
            self.model.decoder.layernorm_embedding, "Embedding", ipu_id=0
        )

        shift = len(self.model.encoder.layers)
        for index, layer in enumerate(self.model.decoder.layers):
            ipu = layer_ipu[index + shift]
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
            logger.info(f"Decoder {index:<2} --> IPU {ipu}")

        last_ipu = layer_ipu[-1]
        logger.info(f"Classification Head Output --> IPU {last_ipu}")
        self.classification_head = poptorch.BeginBlock(
            self.classification_head, "Classification Head Output", ipu_id=last_ipu
        )
        logger.info("-----------------------------------------------------------")
        return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        fully compatible with `transformers.BartForSequenceClassification`.
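
        Illustrative round trip (a sketch; assumes `self.ipu_config` is already set, as `PipelineMixin` expects):
        ```
        model = PipelinedBartForSequenceClassification(config).parallelize().half()
        ...  # training / inference with poptorch
        model.deparallelize()
        model.save_pretrained("path/to/checkpoint")  # placeholder path
        ```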
""" super().deparallelize() self.model.encoder_and_decoder_embeddings_computation(False) self.model.change_bart_encoder_and_decoder_classes(True) self.model.change_bart_attention_class(True) self.model.__class__ = BartModel return self def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: use_cache = False outputs = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] # last hidden state B, L, E = hidden_states.shape eos_mask = torch.eq(input_ids, self.config.eos_token_id) # Static tensor shape version of hidden_states[eos_mask, :] eos_indices = eos_mask * torch.arange(L).unsqueeze(0) last_eos_index, _ = torch.max(eos_indices, dim=1) # torch.index_select requires a 1D tensor of indices last_eos_index += torch.arange(B) * L hidden_states = hidden_states.view(B * L, E) sentence_representation = torch.index_select(hidden_states, 0, last_eos_index) logits = self.classification_head(sentence_representation) loss = None if labels is not None: if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = nn.MSELoss() if self.config.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return Seq2SeqSequenceClassifierOutput( loss=loss, logits=logits, past_key_values=outputs.past_key_values, 
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
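

# Illustrative end-to-end usage sketch (kept as a comment so nothing runs on import; the checkpoint name,
# IPUConfig settings and generation parameters below are assumptions, not requirements of this module, and the
# way the IPUConfig is attached may differ in practice, e.g. when going through optimum-graphcore's Trainer or
# pipelines):
#
#     from transformers import BartTokenizer
#     from optimum.graphcore import IPUConfig
#
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#     model = PipelinedBartForConditionalGeneration.from_pretrained("facebook/bart-base")
#     model.ipu_config = IPUConfig(layers_per_ipu=[3, 3, 3, 3])
#     model.parallelize(for_generation=True, use_cache=True, batch_size=1, max_length=64, num_beams=1).half()
#     inputs = tokenizer("BART summarization on IPUs.", return_tensors="pt")
#     summary_ids = model.generate(**inputs, max_length=64, num_beams=1, use_cache=True)
#     print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))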