# optimum/graphcore/models/t5/modeling_t5.py
# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Tuple, Union
import poptorch
import torch
import torch.nn as nn
from torch import Tensor
from transformers import T5ForConditionalGeneration
from transformers.activations import NewGELUActivation
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG, T5Attention, T5Block, T5EncoderModel, T5Stack
from optimum.utils import logging
from ...generation import IPUAttentionMixin, IPUGenerationMixin, supports_kv_cache
from ...modeling_utils import (
PipelineMixin,
SerializedLinear,
SharedEmbedding,
get_layer_ipu,
recomputation_checkpoint,
register,
split_encoder_decoder_ipu_config,
)
logger = logging.get_logger(__name__)
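

# UpCastWrapper casts a wrapped module's output to FP32 and optionally rescales it. It is
# used below to keep the residual path of the T5 blocks in FP32 and to undo the scaling
# applied to the T5LayerFF down-projection weights.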
class UpCastWrapper(nn.Module):
def __init__(self, module: nn.Module, scale: float = 1.0):
super().__init__()
self.module = module
self.scale = scale
def forward(self, input):
return self.module(input).to(torch.float32) * self.scale
class CustomGELU(NewGELUActivation):
# Work-around bug with torch.nn.GELU(approximate="tanh")
# TODO: Remove this when bug is fixed
def forward(self, input: Tensor) -> Tensor:
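        # The tanh-based GELU overflows for large |input|, so evaluate it on a copy of the
        # input with unsafe values zeroed out, and fall back to ReLU outside (-39, 39),
        # where GELU(x) is numerically indistinguishable from ReLU(x).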
safe = torch.logical_and(-39 < input, input < 39)
safe_input = torch.where(safe, input, 0.0)
gelu = super().forward(safe_input)
relu = nn.functional.relu(input)
return torch.where(safe, gelu, relu)
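

# IPUT5Attention re-implements T5Attention.forward so that the optional on-device
# self-attention and cross-attention KV caches provided by IPUAttentionMixin can be
# used during generation.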
class IPUT5Attention(T5Attention, IPUAttentionMixin):
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
):
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
        # On the IPU the real sequence length is the padded sequence length. If self-attention
        # KV caching is enabled, this length can be obtained from the KV cache.
        # For cross-attention KV caching, computing the relative attention bias is disabled,
        # so we do not need to be aware of the decoder max length.
if self.kv_cache_initialized:
real_seq_length = self._k_cache.shape[-2]
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
# get query states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
# get key/value states
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
key_states = shape(self.k(hidden_states))
value_states = shape(self.v(hidden_states))
elif not self.cross_kv_cache_initialized:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
key_states = shape(self.k(key_value_states))
value_states = shape(self.v(key_value_states))
if self.kv_cache_initialized or self.cross_kv_cache_initialized:
            # Change: the original branching on past_key_value is removed to support prefix tuning.
            # This requires the on-device IPU cache to be aware of the prefix tokens.
if key_value_states is None:
# caching key states for self attention
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
tgt_len = key_states.shape[-2]
if tgt_len != 1:
raise ValueError(f"KV cache expects tgt_len = 1, received {tgt_len}.")
key_states, value_states = self.add_to_kv_cache(key_states, value_states)
else:
# cached cross-attn
key_states, value_states = self.add_to_cross_kv_cache(
key_value_states,
lambda x: shape(self.k(x)),
lambda x: shape(self.v(x)),
)
# compute scores
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
if self.cross_kv_cache_initialized:
raise NotImplementedError(
f"Cross KV caching with {self.has_relative_attention_bias=} is not yet supported on the IPU."
)
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
            # if keys and values are already calculated,
            # we want only the last query's position bias
if self.kv_cache_initialized:
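                # Keep only the position-bias row corresponding to the current generation step.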
position_bias = poptorch.dynamic_slice(position_bias, 2, self._generation_step, 1, 1)
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
mask[list(self.pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias
scores += position_bias_masked
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
outputs = outputs + (attn_weights,)
return outputs
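

# CustomT5Block matches T5Block except that the data-dependent checks for `inf` around
# the FP16 clamping (which cannot be statically traced by PopTorch) are removed.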
class CustomT5Block(T5Block):
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
return_dict=True,
):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
# Custom: Remove check for inf
if hidden_states.dtype == torch.float16:
clamp_value = torch.tensor(torch.finfo(hidden_states.dtype).max - 1000, dtype=hidden_states.dtype)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = cross_attention_outputs[0]
# clamp inf values to enable fp16 training
# Custom: Remove check for inf
if hidden_states.dtype == torch.float16:
clamp_value = torch.tensor(torch.finfo(hidden_states.dtype).max - 1000, dtype=hidden_states.dtype)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:]
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
# clamp inf values to enable fp16 training
# Custom: Remove check for inf
if hidden_states.dtype == torch.float16:
clamp_value = torch.tensor(torch.finfo(hidden_states.dtype).max - 1000, dtype=hidden_states.dtype)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs
else:
outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
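

# CustomT5Stack scales the additive attention masks (see the note below) and, when
# self-attention KV caching is enabled, builds the decoder attention mask from the
# cache state instead of from the inputs.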
class CustomT5Stack(T5Stack):
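    # The extended/inverted masks use the dtype's most negative value at masked positions;
    # scaling by 0.75 presumably keeps the later additions to the attention scores from
    # overflowing in FP16 on the IPU.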
def invert_attention_mask(self, *args, **kwargs) -> Tensor:
return super().invert_attention_mask(*args, **kwargs) * 0.75
def get_extended_attention_mask(self, *args, **kwargs) -> Tensor:
return super().get_extended_attention_mask(*args, **kwargs) * 0.75
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
"""
        Intercept the forward call in order to provide the correct attention mask when self-attention KV caching is enabled.
        The alternative would be to replicate the parent forward call here and skip its (default) construction of an attention
        mask when KV caching is enabled, which would let the attention layers call `IPUT5Attention.update_attention_mask` to
        create a mask with knowledge of the decoder max length. To avoid replicating all but a few lines of code, the
        interception approach is used instead.
"""
if self.is_decoder and self.block[0].layer[0].SelfAttention.kv_cache_initialized:
if attention_mask is None:
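                # Build a (bsz, 1, src_len) mask over the cached positions: ones up to and
                # including the current generation step, zeros for future positions.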
attention_layer = self.block[0].layer[0].SelfAttention
bsz, _, src_len, _ = attention_layer._k_cache.shape
attention_mask = torch.ones((1, src_len))
mask_cond = torch.arange(src_len).view(1, src_len)
attention_mask.masked_fill_(mask_cond >= attention_layer._generation_step + 1, 0)
attention_mask = attention_mask.to(attention_layer._k_cache.dtype)
attention_mask = attention_mask.expand(bsz, 1, src_len)
else:
raise ValueError(
f"Providing an {attention_mask=} to the decoder when kv-caching is enabled is currently not supported."
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
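

# @register(T5ForConditionalGeneration) registers this class as the pipelined IPU
# counterpart of the vanilla model; @supports_kv_cache marks that it implements
# on-device KV caching for generation (see IPUT5Attention above).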
@supports_kv_cache
@register(T5ForConditionalGeneration)
class PipelinedT5ForConditionalGeneration(T5ForConditionalGeneration, PipelineMixin, IPUGenerationMixin):
@property
def is_encoder_and_decoder_embeddings_computation_shared(self):
return isinstance(self.shared, SharedEmbedding)
def encoder_and_decoder_embeddings_computation(self, use_shared_embedding: bool):
"""Sets the T5ForConditionalGeneration shared embedding layer to SharedEmbedding that combines the computation under one layer.
Args:
use_shared_embedding: whether to use SharedEmbedding or not.
"""
if use_shared_embedding:
if isinstance(self.shared, SharedEmbedding):
logger.warning("encoder and decoder embeddings computation is already shared")
else:
self.shared = SharedEmbedding(self.shared)
else:
if isinstance(self.shared, nn.Embedding):
logger.warning("encoder and decoder embeddings computation is not shared")
else:
self.shared = self.shared.shared
def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=False, **kwargs):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
        - Replaces the shared embedding with a SharedEmbedding that combines the encoder and decoder embedding computation
        - (If enabled) Replaces the LM head with a SerializedLinear layer
- Adds recomputation checkpoints
Recommended usage:
```
model = PipelinedT5ForConditionalGeneration(config).parallelize().half()
```
"""
PipelineMixin.parallelize(self)
if use_cache:
kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
if self.ipu_config.embedding_serialization_factor > 1:
self.lm_head = SerializedLinear.from_model(self.lm_head, self.ipu_config.embedding_serialization_factor)
            # TODO: is this check needed?
if self.config.tie_word_embeddings:
self.tie_weights()
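        # change_lm_head_to_indexed_input_linear presumably lets the LM head project only
        # the positions that are actually needed; keep the indexed head only when generating
        # without KV caching (with caching, a single position is decoded per step anyway).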
self.change_lm_head_to_indexed_input_linear(restore=not (for_generation and not use_cache))
self.encoder_and_decoder_embeddings_computation(True)
self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0)
        # Use a custom T5Stack implementation because sharing the position bias causes an OOM error
self.encoder.__class__ = CustomT5Stack
self.decoder.__class__ = CustomT5Stack
# Optimisations for generation
self.change_attention_class(
restore=False,
use_cache=use_cache and for_generation,
use_cross_cache=use_cross_cache and for_generation,
**kwargs,
)
self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))
# Upcast input embeddings so that the residuals remain in FP32. This
# cast is reversed where necessary by the T5LayerNorm layers in:
# - first layer of T5LayerSelfAttention
# - first layer of T5LayerFF
# - final_layer_norm
        # These are, conveniently, all the places where this needs to happen,
        # so we just need to upcast immediately before the residual
# adds in T5LayerSelfAttention and T5LayerFF. This is handled in the
# for loop below.
self.encoder.embed_tokens = UpCastWrapper(self.encoder.embed_tokens)
        # Use a custom T5Block implementation that removes dynamic if-blocks that can't be statically traced
for block in self.encoder.block:
block.__class__ = CustomT5Block
# Dropout happens immediately before the residual add. Inserting a
# cast in T5LayerSelfAttention and T5LayerFF keeps the residual
# structure in FP32
block.layer[0].dropout = UpCastWrapper(block.layer[0].dropout)
# Scale down the weights for the T5LayerFF down-projection and
# then scale its output back up again after it is cast to FP32
scale = 8.0
with torch.no_grad():
block.layer[1].DenseReluDense.wo.weight /= scale
block.layer[1].dropout = UpCastWrapper(block.layer[1].dropout, scale)
# Prevent overflow in NewGELUActivation
if self.config.dense_act_fn == "gelu_new":
# TODO: Work-around bug with torch.nn.GELU(approximate="tanh"). Replace
# this with block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")
# when bug is fixed
block.layer[1].DenseReluDense.act = CustomGELU()
for block in self.decoder.block:
block.__class__ = CustomT5Block
# Work-around bug with torch.nn.GELU(approximate="tanh")
# TODO: Remove this when bug is fixed
if self.config.dense_act_fn == "gelu_new":
block.layer[2].DenseReluDense.act = CustomGELU()
num_encoder_layers = len(self.encoder.block)
num_decoder_layers = len(self.decoder.block)
if for_generation:
# If running for text generation we split the IPU config into two configs
# because we run the encoder and decoder as separate Poplar executors.
ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
else:
number_of_layers = num_encoder_layers + num_decoder_layers
layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
encoder_layer_ipu = layer_ipu[:num_encoder_layers]
decoder_layer_ipu = layer_ipu[num_encoder_layers:]
for index, (layer, ipu) in enumerate(zip(self.encoder.block, encoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
self.encoder.final_layer_norm = poptorch.BeginBlock(
self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu
)
for index, (layer, ipu) in enumerate(zip(self.decoder.block, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
logger.info(f"Decoder {index:<2} --> IPU {ipu}")
self.decoder.final_layer_norm = poptorch.BeginBlock(
self.decoder.final_layer_norm, "Decoder Stack Final LayerNorm", ipu_id=ipu
)
logger.info("LM Head Output --> IPU 0")
self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
fully compatible with `transformers.T5ForConditionalGeneration`.
"""
# T5ForConditionalGeneration has a deparallelize method, so make sure that the PipelineMixin one is used here.
PipelineMixin.deparallelize(self)
self.encoder_and_decoder_embeddings_computation(False)
self.set_on_device_generation_steps(0)
self.change_attention_class(restore=True)
self.encoder.__class__ = T5Stack
self.decoder.__class__ = T5Stack
self.encoder.embed_tokens = self.encoder.embed_tokens.module
for block in self.encoder.block:
block.__class__ = T5Block
block.layer[0].dropout = block.layer[0].dropout.module
with torch.no_grad():
block.layer[1].DenseReluDense.wo.weight *= block.layer[1].dropout.scale
block.layer[1].dropout = block.layer[1].dropout.module
if self.config.dense_act_fn == "gelu_new":
block.layer[1].DenseReluDense.act = NewGELUActivation()
for block in self.decoder.block:
block.__class__ = T5Block
if self.config.dense_act_fn == "gelu_new":
block.layer[2].DenseReluDense.act = NewGELUActivation()
self.change_lm_head_to_indexed_input_linear(restore=True)
if isinstance(self.lm_head, SerializedLinear):
self.lm_head = self.lm_head.to_model()
            # TODO: is this check needed?
if self.config.tie_word_embeddings:
self.tie_weights()
return self
def change_attention_class(self, restore=False, **kwargs):
"""Changes the attention layers to either use the original T5Attention forward
or IPUT5Attention forward.
Args:
restore (bool, optional): whether to restore the attention layers to their original version or not. Defaults to False.
"""
use_cache = kwargs.get("use_cache", False)
use_cross_cache = kwargs.get("use_cross_cache", False)
batch_size = kwargs.get("batch_size", 1)
max_length = kwargs.get("max_length", 128)
encoder_max_length = kwargs.get("encoder_max_length", 1500)
num_beams = kwargs.get("num_beams", 1)
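        # The encoder processes the full padded input in a single pass, so its self-attention
        # never uses a KV cache. Only decoder self-attention (sized by max_length) and
        # cross-attention (sized by encoder_max_length) get on-device caches when enabled.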
for layer in self.encoder.block:
if restore:
layer.layer[0].SelfAttention = layer.layer[0].SelfAttention.to_model(T5Attention)
continue
layer.layer[0].SelfAttention = IPUT5Attention.from_model(
layer.layer[0].SelfAttention,
use_cache=False,
)
for layer in self.decoder.block:
if restore:
layer.layer[0].SelfAttention = layer.layer[0].SelfAttention.to_model(T5Attention)
layer.layer[1].EncDecAttention = layer.layer[1].EncDecAttention.to_model(T5Attention)
continue
layer.layer[0].SelfAttention = IPUT5Attention.from_model(
layer.layer[0].SelfAttention,
use_cache=use_cache,
use_cross_cache=False,
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
num_heads=layer.layer[0].SelfAttention.n_heads,
head_dim=layer.layer[0].SelfAttention.key_value_proj_dim,
dtype=layer.layer[0].SelfAttention.k.weight.dtype,
)
layer.layer[1].EncDecAttention = IPUT5Attention.from_model(
layer.layer[1].EncDecAttention,
use_cache=False,
use_cross_cache=use_cross_cache,
batch_size=batch_size,
encoder_max_length=encoder_max_length,
num_beams=num_beams,
num_heads=layer.layer[1].EncDecAttention.n_heads,
head_dim=layer.layer[1].EncDecAttention.key_value_proj_dim,
dtype=layer.layer[1].EncDecAttention.k.weight.dtype,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# We don't use `past_key_values` for KV caching, and rely on `use_cache` instead.
beam_idx = None
if use_cache:
# cut decoder_input_ids if past is used
input_ids = input_ids[:, -1:]
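            # beam_idx is presumably consumed by the on-device KV cache to reorder cached
            # entries after beam search; it defaults to the identity permutation.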
beam_idx = kwargs.get("beam_idx", torch.arange(input_ids.shape[0], dtype=torch.long))
return {
"decoder_input_ids": input_ids,
"past_key_values": None,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
"beam_idx": beam_idx,
}
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
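        # This mirrors transformers.T5ForConditionalGeneration.forward, with IPU-specific
        # changes: the optional combined shared-embedding computation below and the
        # loss-only return path at the end to reduce host <-> device communication.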
if self.is_encoder_and_decoder_embeddings_computation_shared:
inputs_embeds, decoder_inputs_embeds = self.shared(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
)
if inputs_embeds is not None:
input_ids = None
if decoder_inputs_embeds is not None:
decoder_input_ids = None
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = decoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_scale_modifier = getattr(self, "lm_scale_modifier", None)
if lm_scale_modifier is not None:
sequence_output = sequence_output * lm_scale_modifier
lm_logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
# Only returning the loss to make the communication between the host and the device faster.
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return (loss,) if labels is not None else output
if loss is not None:
return Seq2SeqLMOutput(
loss=loss,
)
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
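

# PipelinedT5EncoderModel applies the same encoder-side transformations as
# PipelinedT5ForConditionalGeneration above, but has no decoder, LM head, or
# generation support.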
@register(T5EncoderModel)
class PipelinedT5EncoderModel(T5EncoderModel, PipelineMixin):
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- Adds recomputation checkpoints
Recommended usage:
```
model = PipelinedT5EncoderModel(config).parallelize().half()
```
"""
PipelineMixin.parallelize(self)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0)
        # Use a custom T5Stack implementation because sharing the position bias causes an OOM error
self.encoder.__class__ = CustomT5Stack
# Upcast input embeddings so that the residuals remain in FP32. This
# cast is reversed where necessary by the T5LayerNorm layers in:
# - first layer of T5LayerSelfAttention
# - first layer of T5LayerFF
# - final_layer_norm
        # These are, conveniently, all the places where this needs to happen,
        # so we just need to upcast immediately before the residual
# adds in T5LayerSelfAttention and T5LayerFF. This is handled in the
# for loop below.
self.encoder.embed_tokens = UpCastWrapper(self.encoder.embed_tokens)
        # Use a custom T5Block implementation that removes dynamic if-blocks that can't be statically traced
for block in self.encoder.block:
block.__class__ = CustomT5Block
# Dropout happens immediately before the residual add. Inserting a
# cast in T5LayerSelfAttention and T5LayerFF keeps the residual
# structure in FP32
block.layer[0].dropout = UpCastWrapper(block.layer[0].dropout)
# Scale down the weights for the T5LayerFF down-projection and
# then scale its output back up again after it is cast to FP32
scale = 8.0
with torch.no_grad():
block.layer[1].DenseReluDense.wo.weight /= scale
block.layer[1].dropout = UpCastWrapper(block.layer[1].dropout, scale)
# Prevent overflow in NewGELUActivation
if self.config.dense_act_fn == "gelu_new":
# TODO: Work-around bug with torch.nn.GELU(approximate="tanh"). Replace
# this with block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")
# when bug is fixed
block.layer[1].DenseReluDense.act = CustomGELU()
num_encoder_layers = len(self.encoder.block)
number_of_layers = num_encoder_layers
encoder_layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
for index, (layer, ipu) in enumerate(zip(self.encoder.block, encoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
self.encoder.final_layer_norm = poptorch.BeginBlock(
self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu
)
return self
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
        fully compatible with `transformers.T5EncoderModel`.
        """
        # T5EncoderModel also has a deparallelize method, so make sure that the PipelineMixin one is used here.
PipelineMixin.deparallelize(self)
self.encoder.__class__ = T5Stack
self.encoder.embed_tokens = self.encoder.embed_tokens.module
for block in self.encoder.block:
block.__class__ = T5Block
block.layer[0].dropout = block.layer[0].dropout.module
with torch.no_grad():
block.layer[1].DenseReluDense.wo.weight *= block.layer[1].dropout.scale
block.layer[1].dropout = block.layer[1].dropout.module
if self.config.dense_act_fn == "gelu_new":
block.layer[1].DenseReluDense.act = NewGELUActivation()
return self