# optimum/graphcore/models/bart/modeling_bart.py
# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import random
from typing import List, Optional, Tuple, Union
import poptorch
import torch
import torch.nn as nn
from transformers import BartForConditionalGeneration, BartForSequenceClassification, BartModel
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqSequenceClassifierOutput,
)
from transformers.models.bart.modeling_bart import (
BartAttention,
BartDecoder,
BartEncoder,
BartEncoderLayer,
BartLearnedPositionalEmbedding,
)
from optimum.utils import logging
from ...generation import IPUAttentionMixin, IPUGenerationMixin, supports_kv_cache
from ...modeling_utils import (
PipelineMixin,
SerializedLinear,
SharedEmbedding,
get_layer_ipu,
recomputation_checkpoint,
register,
shift_tokens_right,
split_encoder_decoder_ipu_config,
)
logger = logging.get_logger(__name__)
FLOAT16_LIMIT = 1e4
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""Makes causal mask used for bi-directional self-attention.
This differs from the original implementation by:
- Making the mask creation simpler in terms of operations used
- Changing the value for tokens to mask to something compatible with fp16
- Not expanding the final mask to [bsz, 1, tgt_len, src_len]
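    Example (illustrative sketch, not executed on the IPU):
    ```
    mask = _make_causal_mask(torch.Size([1, 3]), torch.float16)
    # mask has shape [1, 1, 3, 3]; entries strictly above the diagonal are -FLOAT16_LIMIT, the rest are 0
    ```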
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), -FLOAT16_LIMIT, dtype=dtype)
mask = torch.triu(mask, diagonal=1).to(dtype=dtype)
return mask[None, None, :, :]
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, 1, src_seq_len]`.
This differs from the original implementation by:
- Changing the value for tokens to mask to something compatible with fp16
- Not expanding the final mask to [bsz, 1, tgt_len, src_len]
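    Example (illustrative sketch, not executed on the IPU):
    ```
    mask = torch.tensor([[1.0, 1.0, 0.0]])   # [bsz, src_len]; 0 marks padding
    _expand_mask(mask, torch.float16)        # shape [1, 1, 1, 3]; padded position -> -FLOAT16_LIMIT, valid -> 0
    ```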
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :]
inverted_mask = 1.0 - expanded_mask
# Using FLOAT16_LIMIT instead of -float("inf") to avoid NaNs on the IPUs.
inverted_mask = -FLOAT16_LIMIT * inverted_mask
return inverted_mask.to(dtype)
class IPUBartAttention(BartAttention, IPUAttentionMixin):
"""The same as BartAttention without the attention mask shape check.
This is needed because the original BartAttention checks that the attention mask shape is [bs, 1, tgt_len, src_len]
but the pipelined implementation does not expand the mask, it just inserts dimensions, the shape is then
[bs, 1, 1, src_len], and broadcasting does the rest.
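    Broadcasting sketch (illustrative only):
    ```
    attn_weights = torch.zeros(2, 4, 5, 7)    # [bs, num_heads, tgt_len, src_len]
    attention_mask = torch.zeros(2, 1, 1, 7)  # [bs, 1, 1, src_len]
    (attn_weights + attention_mask).shape     # torch.Size([2, 4, 5, 7])
    ```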
"""
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
if key_value_states is not None:
# cross attention
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif self.kv_cache_initialized:
# self attention with kv cache
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if tgt_len != 1:
raise ValueError(f"KV cache expects tgt_len = 1, received {tgt_len}.")
key_states, value_states = self.add_to_kv_cache(key_states, value_states)
attention_mask = self.update_attention_mask(attention_mask)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
            # For the decoder, save Tuple(torch.Tensor, torch.Tensor) of the current key/value states.
            # Further calls to a cross-attention layer can then reuse these cross-attention
            # key/value states (first "if" case above).
            # For uni-directional self-attention, reuse of previous decoder key/value states is handled by
            # the static KV cache when it is enabled (second "elif" case above).
            # For the encoder's bi-directional self-attention, `past_key_value` is always `None`.
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
if attention_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across devices when using tensor parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class _BartEncoderLayerNoClamp(BartEncoderLayer):
"""
    Same as BartEncoderLayer except that it removes the dynamic `if` statement
    used for clamping fp16 tensor values, which cannot be statically compiled.
"""
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
layer_head_mask: torch.FloatTensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
"""
Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
# Change: removing this `if` because it can't be statically compiled.
# if hidden_states.dtype == torch.float16 and (
# torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
# ):
# clamp_value = torch.finfo(hidden_states.dtype).max - 1000
# hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class _BartEncoderWithCustomExpandMask(BartEncoder):
"""The same as BartEncoder but uses a custom version of _expand_mask.
Check the _expand_mask docstring for more information.
"""
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=False,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input = input_ids
input_ids = input_ids.view(-1, input_ids.shape[-1])
elif inputs_embeds is not None:
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, 1, src_seq_len] (the custom _expand_mask does not expand along tgt_len)
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class _BartDecoderWithCustomMakeCausalAndExpandMask(BartDecoder):
"""The same as BartDecoder but uses a custom version of _make_causal_mask and _expand_mask.
    Check the _make_causal_mask and _expand_mask docstrings for more information.
"""
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [1, 1, tgt_seq_len, tgt_seq_len] (the custom _make_causal_mask does not expand along bsz)
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, 1, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input = input_ids
input_shape = input.shape
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, 1, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# embed positions
positions = self.embed_positions(input, past_key_values_length)
positions = positions.to(inputs_embeds.device)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
        # The check below that head_mask/cross_attn_head_mask are specified for the correct number of layers
        # is disabled because it cannot be statically compiled.
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
pass
                # if attn_mask.size()[0] != (len(self.layers)):
                #     raise ValueError(
                #         f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
                #     )
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class IPUBartLearnedPositionalEmbedding(BartLearnedPositionalEmbedding):
"""
    This module learns positional embeddings up to a fixed maximum size. Compared to the original, it also
    tracks the current generation step so that a single position embedding can be selected when the static
    KV cache is enabled.
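    Example (sketch, assuming `decoder` is a `BartDecoder` instance; in practice `_generation_step` is
    maintained by the IPU generation loop):
    ```
    embed = IPUBartLearnedPositionalEmbedding.from_model(decoder.embed_positions)
    embed._generation_step.fill_(5)
    embed(torch.zeros(1, 1, dtype=torch.long))  # selects only the embedding for position 5 (+ offset)
    ```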
"""
@classmethod
def from_model(cls, model: BartLearnedPositionalEmbedding):
clone = copy.deepcopy(model)
clone.__class__ = cls
clone.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False)
return clone
def to_model(self) -> BartLearnedPositionalEmbedding:
del self._generation_step
original = copy.deepcopy(self)
original.__class__ = BartLearnedPositionalEmbedding
return original
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
if input_ids.shape[-1] == 1:
# KV cache enabled.
del past_key_values_length
return torch.index_select(self.weight, 0, self._generation_step + self.offset)
else:
return super().forward(input_ids, past_key_values_length)
class _BartModelWithSharedEmbedding(BartModel):
@property
def is_encoder_and_decoder_embeddings_computation_shared(self):
return isinstance(self.shared, SharedEmbedding)
def encoder_and_decoder_embeddings_computation(self, use_shared_embedding: bool):
"""Sets the BartModel shared embedding layer to SharedEmbedding that combines the computation under one layer.
Args:
use_shared_embedding: whether to use SharedEmbedding or not.
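        Example (sketch, assuming `bart_model` is a `_BartModelWithSharedEmbedding`):
        ```
        bart_model.encoder_and_decoder_embeddings_computation(use_shared_embedding=True)
        assert bart_model.is_encoder_and_decoder_embeddings_computation_shared
        ```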
"""
if use_shared_embedding:
if isinstance(self.shared, SharedEmbedding):
logger.warning("encoder and decoder embeddings computation is already shared")
else:
self.shared = SharedEmbedding(self.shared)
else:
if isinstance(self.shared, nn.Embedding):
logger.warning("encoder and decoder embeddings computation is not shared")
else:
self.shared = self.shared.shared
def change_bart_encoder_and_decoder_classes(self, restore: bool):
"""Changes the encoder and decoder classes to update their forward pass so that they use our custom versions of
_make_causal_mask and _expand_mask.
Args:
restore: whether to restore the encoder and decoder to their original version or not.
"""
self.encoder.__class__ = BartEncoder if restore else _BartEncoderWithCustomExpandMask
self.decoder.__class__ = BartDecoder if restore else _BartDecoderWithCustomMakeCausalAndExpandMask
for layer in self.encoder.layers:
layer.__class__ = BartEncoderLayer if restore else _BartEncoderLayerNoClamp
def change_bart_attention_class(self, restore: bool, **kwargs):
"""Changes the attention layers to either use the original BartAttention forward or
BartAttentionWithoutException forward.
Args:
restore: whether to restore the attention layers to their original version or not.
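        Example (sketch, assuming `bart_model` is a `_BartModelWithSharedEmbedding`; `batch_size`, `max_length`
        and `num_beams` are only read when the KV cache is enabled):
        ```
        bart_model.change_bart_attention_class(restore=False, use_cache=True, batch_size=1, max_length=128)
        bart_model.change_bart_attention_class(restore=True)  # back to the original BartAttention
        ```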
"""
use_cache = kwargs.get("use_cache", False)
batch_size = kwargs.get("batch_size", 1)
max_length = kwargs.get("max_length", 128)
num_beams = kwargs.get("num_beams", 1)
for encoder_layer in self.encoder.layers:
if restore:
encoder_layer.self_attn = encoder_layer.self_attn.to_model(BartAttention)
continue
encoder_layer.self_attn = IPUBartAttention.from_model(
encoder_layer.self_attn,
use_cache=False,
)
for decoder_layer in self.decoder.layers:
if restore:
decoder_layer.self_attn = decoder_layer.self_attn.to_model(BartAttention)
decoder_layer.encoder_attn = decoder_layer.encoder_attn.to_model(BartAttention)
continue
decoder_layer.self_attn = IPUBartAttention.from_model(
decoder_layer.self_attn,
use_cache=use_cache,
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
dtype=decoder_layer.self_attn.k_proj.weight.dtype,
)
decoder_layer.encoder_attn = IPUBartAttention.from_model(
decoder_layer.encoder_attn,
use_cache=False,
)
def change_decoder_positional_embedding(self, restore: bool):
"""Changes the decoder positional embedding to support an optional static KV cache.
Args:
            restore: whether to restore the decoder positional embedding to its original version or not.
"""
position_embedding = self.decoder.embed_positions
self.decoder.embed_positions = (
position_embedding.to_model()
if restore
else IPUBartLearnedPositionalEmbedding.from_model(position_embedding)
)
def quantize_linear_layers(self, restore: bool, num_groups: int = 16):
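        """Optionally group-quantizes the linear layers of the encoder and decoder.
        Args:
            restore: if True this is a no-op and the layers are left unchanged.
            num_groups: the number of quantization groups passed to `GroupQuantLinear.from_model`.
        """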
if restore:
return
from ...quantization.group_quantize import GroupQuantLinear
logger.info("Group quantizing linear layers")
for module in self.encoder.layers:
module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)
for module in self.decoder.layers:
module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
module.encoder_attn.q_proj = GroupQuantLinear.from_model(module.encoder_attn.q_proj, num_groups)
module.encoder_attn.k_proj = GroupQuantLinear.from_model(module.encoder_attn.k_proj, num_groups)
module.encoder_attn.v_proj = GroupQuantLinear.from_model(module.encoder_attn.v_proj, num_groups)
module.encoder_attn.out_proj = GroupQuantLinear.from_model(module.encoder_attn.out_proj, num_groups)
module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
        # Unlike other models, Bart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
if decoder_input_ids is None and decoder_inputs_embeds is None:
if input_ids is None:
raise ValueError(
"If no `decoder_input_ids` or `decoder_inputs_embeds` are "
"passed, `input_ids` cannot be `None`. Please pass either "
"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
)
decoder_input_ids = shift_tokens_right(
input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
if self.is_encoder_and_decoder_embeddings_computation_shared:
inputs_embeds, decoder_inputs_embeds = self.shared(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
encoder_embed_scale=self.encoder.embed_scale,
decoder_embed_scale=self.decoder.embed_scale,
)
if inputs_embeds is not None:
input_ids = None
if decoder_inputs_embeds is not None:
decoder_input_ids = None
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@supports_kv_cache
@register(BartForConditionalGeneration)
class PipelinedBartForConditionalGeneration(BartForConditionalGeneration, PipelineMixin, IPUGenerationMixin):
def parallelize(self, for_generation=False, use_cache=False, **kwargs):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
        - (If `embedding_serialization_factor > 1`) Replaces the LM head with a SerializedLinear layer
- Adds recomputation checkpoints
Recommended usage:
```
model = PipelinedBartForConditionalGeneration(config).parallelize().half()
```
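        For text generation with the static KV cache, a possible call is (sketch; `batch_size`, `max_length`
        and `num_beams` are forwarded to the attention layers' KV cache):
        ```
        model = PipelinedBartForConditionalGeneration(config).parallelize(
            for_generation=True, use_cache=True, batch_size=1, max_length=128, num_beams=1
        ).half()
        ```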
"""
super().parallelize()
if use_cache:
kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
if self.ipu_config.embedding_serialization_factor > 1:
self.lm_head = SerializedLinear.from_model(self.lm_head, self.ipu_config.embedding_serialization_factor)
self.tie_weights()
self.model.__class__ = _BartModelWithSharedEmbedding
self.model.encoder_and_decoder_embeddings_computation(use_shared_embedding=True)
self.model.change_bart_encoder_and_decoder_classes(restore=False)
self.model.change_bart_attention_class(restore=False, use_cache=use_cache and for_generation, **kwargs)
self.model.change_decoder_positional_embedding(restore=False)
self.change_lm_head_to_indexed_input_linear(restore=not (for_generation and not use_cache))
self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))
self.model.quantize_linear_layers(restore=not kwargs.get("use_group_quantized_linears", False), num_groups=16)
self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0)
self.model.encoder.embed_positions = poptorch.BeginBlock(
self.model.encoder.embed_positions, "Embedding", ipu_id=0
)
self.model.encoder.layernorm_embedding = poptorch.BeginBlock(
self.model.encoder.layernorm_embedding, "Embedding", ipu_id=0
)
num_encoder_layers = len(self.model.encoder.layers)
num_decoder_layers = len(self.model.decoder.layers)
if for_generation:
# If running for text generation we split the IPU config into two configs
# because we run the encoder and decoder as separate Poplar executors.
ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
else:
number_of_layers = num_encoder_layers + num_decoder_layers
layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
encoder_layer_ipu = layer_ipu[:num_encoder_layers]
decoder_layer_ipu = layer_ipu[num_encoder_layers:]
for index, (layer, ipu) in enumerate(zip(self.model.encoder.layers, encoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
self.model.decoder.embed_positions = poptorch.BeginBlock(
self.model.decoder.embed_positions, "Embedding", ipu_id=0
)
self.model.decoder.layernorm_embedding = poptorch.BeginBlock(
self.model.decoder.layernorm_embedding, "Embedding", ipu_id=0
)
for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
logger.info(f"Decoder {index:<2} --> IPU {ipu}")
logger.info("LM Head Output --> IPU 0")
self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
fully compatible with `transformers.BartForConditionalGeneration`.
"""
super().deparallelize()
self.model.encoder_and_decoder_embeddings_computation(False)
self.model.change_bart_encoder_and_decoder_classes(True)
self.model.change_bart_attention_class(True)
self.model.change_decoder_positional_embedding(restore=True)
self.model.__class__ = BartModel
self.change_lm_head_to_indexed_input_linear(restore=True)
self.set_on_device_generation_steps(0)
if isinstance(self.lm_head, SerializedLinear):
self.lm_head = self.lm_head.to_model()
self.tie_weights()
return self
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
decoder_attention_mask=None,
**kwargs,
):
# We don't use `past_key_values` for KV caching, and rely on `use_cache` instead.
beam_idx = None
if use_cache:
decoder_input_ids = decoder_input_ids[:, -1:]
beam_idx = kwargs.get("beam_idx", torch.arange(decoder_input_ids.shape[0], dtype=torch.long))
return {
"encoder_outputs": encoder_outputs,
"past_key_values": None,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"beam_idx": beam_idx,
}
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
"""
outputs = super().forward(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.training:
# Only returning the loss to make the communication between the host and the device faster.
if not return_dict:
return outputs[0:1]
else:
return Seq2SeqLMOutput(loss=outputs.loss)
else:
return outputs
@register(BartForSequenceClassification)
class PipelinedBartForSequenceClassification(BartForSequenceClassification, PipelineMixin):
def parallelize(self, **kwargs):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- Adds recomputation checkpoints
Recommended usage:
```
model = PipelinedBartForSequenceClassification(config).parallelize().half()
```
"""
super().parallelize()
self.model.__class__ = _BartModelWithSharedEmbedding
self.model.encoder_and_decoder_embeddings_computation(use_shared_embedding=True)
self.model.change_bart_encoder_and_decoder_classes(restore=False)
self.model.change_bart_attention_class(restore=False)
self.model.quantize_linear_layers(restore=not kwargs.get("use_group_quantized_linears", False), num_groups=16)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0)
self.model.encoder.embed_positions = poptorch.BeginBlock(
self.model.encoder.embed_positions, "Embedding", ipu_id=0
)
self.model.encoder.layernorm_embedding = poptorch.BeginBlock(
self.model.encoder.layernorm_embedding, "Embedding", ipu_id=0
)
number_of_layers = len(self.model.encoder.layers) + len(self.model.decoder.layers)
layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
for index, layer in enumerate(self.model.encoder.layers):
ipu = layer_ipu[index]
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
self.model.decoder.embed_positions = poptorch.BeginBlock(
self.model.decoder.embed_positions, "Embedding", ipu_id=0
)
self.model.decoder.layernorm_embedding = poptorch.BeginBlock(
self.model.decoder.layernorm_embedding, "Embedding", ipu_id=0
)
shift = len(self.model.encoder.layers)
for index, layer in enumerate(self.model.decoder.layers):
ipu = layer_ipu[index + shift]
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
logger.info(f"Decoder {index:<2} --> IPU {ipu}")
last_ipu = layer_ipu[-1]
logger.info(f"Classification Head Output --> IPU {last_ipu}")
self.classification_head = poptorch.BeginBlock(
self.classification_head, "Classification Head Output", ipu_id=last_ipu
)
logger.info("-----------------------------------------------------------")
return self
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
fully compatible with `transformers.BartForSequenceClassification`.
"""
super().deparallelize()
self.model.encoder_and_decoder_embeddings_computation(False)
self.model.change_bart_encoder_and_decoder_classes(True)
self.model.change_bart_attention_class(True)
self.model.__class__ = BartModel
return self
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0] # last hidden state
B, L, E = hidden_states.shape
eos_mask = torch.eq(input_ids, self.config.eos_token_id)
# Static tensor shape version of hidden_states[eos_mask, :]
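        # For each sequence, compute the index of its *last* EOS token, then gather that row from the
        # flattened [B * L, E] hidden states, so that all tensor shapes stay static for the IPU.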
eos_indices = eos_mask * torch.arange(L).unsqueeze(0)
last_eos_index, _ = torch.max(eos_indices, dim=1)
# torch.index_select requires a 1D tensor of indices
last_eos_index += torch.arange(B) * L
hidden_states = hidden_states.view(B * L, E)
sentence_representation = torch.index_select(hidden_states, 0, last_eos_index)
logits = self.classification_head(sentence_representation)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.config.num_labels == 1:
self.config.problem_type = "regression"
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = nn.MSELoss()
if self.config.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqSequenceClassifierOutput(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)