# optimum/graphcore/models/whisper/modeling_whisper.py
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Optional, Tuple, Union

import poptorch
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperAttention,
    WhisperDecoder,
    WhisperEncoder,
    WhisperEncoderLayer,
    WhisperForConditionalGeneration,
    WhisperPositionalEmbedding,
)

from optimum.utils import logging

from ...generation import IPUAttentionMixin, IPUGenerationMixin, assert_poptorch_supports_cond, supports_kv_cache
from ...modeling_utils import (
    PipelineMixin,
    SerializedLinear,
    get_layer_ipu,
    recomputation_checkpoint,
    register,
    shift_tokens_right,
    split_encoder_decoder_ipu_config,
)


logger = logging.get_logger(__name__)

FLOAT16_LIMIT = 1e4

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), -FLOAT16_LIMIT)
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
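
# For illustration: with input_ids_shape = (1, 3), past_key_values_length = 0 and fp16, the mask
# broadcast to every head is
#     [[   0, -1e4, -1e4],
#      [   0,    0, -1e4],
#      [   0,    0,    0]]
# so position i can only attend to positions <= i. -FLOAT16_LIMIT stands in for -inf to keep the
# values representable in fp16 on IPU.
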
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), -FLOAT16_LIMIT)
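
# For illustration: a padding mask [[1, 1, 0]] of shape [bsz=1, src_len=3] expands to shape
# [1, 1, tgt_len, 3], with kept positions mapped to 0 and padded positions to -FLOAT16_LIMIT so that
# they are suppressed after the softmax.
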
class IPUWhisperAttention(WhisperAttention, IPUAttentionMixin):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, tgt_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states) * self.scaling
if key_value_states is not None:
if self.cross_kv_cache_initialized:
# cross attention with cross kv cache
key_states, value_states = self.add_to_cross_kv_cache(
key_value_states,
lambda x: self._shape(self.k_proj(x), -1, bsz),
lambda x: self._shape(self.v_proj(x), -1, bsz),
)
else:
# cross attention
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif self.kv_cache_initialized:
# self attention with kv cache
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if tgt_len != 1:
raise ValueError(f"KV cache expects tgt_len = 1, received {tgt_len}.")
key_states, value_states = self.add_to_kv_cache(key_states, value_states)
attention_mask = self.update_attention_mask(attention_mask)
else:
# self attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
            # We handle the KV cache via buffers rather than via the eager approach of passing the cache around.
            # `past_key_value` is retained only so the upstream DecoderLayer can stay as is and the tests keep passing.
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
# Change: optionally serialize attention, mainly intended for the encoder with large sequence length.
if self.is_attention_serialized:
if layer_head_mask is not None:
raise ValueError("layer_head_mask is not supported yet with serialized attention.")
if self.dropout and self.training:
raise ValueError("dropout is not supported yet with serialized attention.")
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attention_mask = attention_mask.view(bsz, tgt_len, src_len).repeat(self.num_heads, 1, 1)
attn_output = self.serialized_attention(query_states, key_states, value_states, 1.0, attention_mask)
else:
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# Change: delete optional reshaping of attn_weights
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
# Change: modified check for output_attentions
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
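
# Rough usage sketch (mirrors the calls made in `change_attention_class` further down; the exact values
# are illustrative): during generation the decoder self-attention is swapped for this class with
#     layer.self_attn = IPUWhisperAttention.from_model(
#         layer.self_attn, use_cache=True, batch_size=1, max_length=448, num_beams=1, dtype=torch.float16
#     )
# after which every decoder step feeds a single new token (tgt_len == 1) and keys/values are read from
# and written to on-device buffers instead of being threaded through `past_key_value`.
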
class _WhisperEncoderLayerClamp(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
        # NOTE: This differs from the original implementation.
        # There is a type-mismatch bug with this call to clamp, so we remove it here. It is in any case not
        # needed on the IPU, where FP16 values are clamped to the maximum representable value by default.
        # TODO: when the bug is fixed in a future SDK, remove this entire class.
        # clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class IPUWhisperPositionalEmbedding(WhisperPositionalEmbedding):
@classmethod
def from_model(cls, model: WhisperPositionalEmbedding):
clone = copy.deepcopy(model)
clone.__class__ = cls
clone.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False)
return clone
def to_model(self) -> WhisperPositionalEmbedding:
del self._generation_step
original = copy.deepcopy(self)
original.__class__ = WhisperPositionalEmbedding
return original
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
if input_ids.shape[-1] == 1:
# KV cache enabled.
del past_key_values_length
return torch.index_select(self.weight, 0, self._generation_step)
else:
return super().forward(input_ids, past_key_values_length=past_key_values_length)
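
# For illustration (values are made up): once the KV cache is enabled the decoder sees one token per
# step, so the position is selected through the on-device `_generation_step` buffer rather than derived
# from the sequence length:
#     pos_emb = IPUWhisperPositionalEmbedding.from_model(decoder.embed_positions)
#     pos_emb._generation_step.fill_(5)
#     pos_emb(torch.zeros(1, 1, dtype=torch.long))  # returns row 5 of the positional table
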
class _WhisperDecoderWithCustomMakeCausalAndExpandMask(WhisperDecoder):
"""
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`].
Args:
config: WhisperConfig
"""
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
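
# For illustration: for a target of length 3 whose last position is padding, the combined mask is the
# element-wise sum of the causal mask from `_make_causal_mask` and the expanded padding mask from
# `_expand_mask`, so a position is masked out if it is either in the future or padded.
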
class IPUWhisperConditionalEncoder(WhisperEncoder):
@classmethod
def from_model(cls, model: WhisperEncoder, batch_size: int, num_beams: int):
clone = model
clone.__class__ = cls
clone.register_buffer(
"_encoder_last_hidden_state",
torch.zeros((batch_size, model.config.max_source_positions, model.config.d_model), dtype=model.dtype),
persistent=False,
)
clone.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False)
clone._batch_size = batch_size
clone._num_beams = num_beams
return clone
def to_model(self) -> WhisperEncoder:
del self._encoder_last_hidden_state
del self._generation_step
del self._batch_size
del self._num_beams
original = self
original.__class__ = WhisperEncoder
return original
def forward(
self,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
if attention_mask is not None or head_mask is not None or output_attentions or output_hidden_states:
raise ValueError(f"{self.__class__.__name__} only accepts `input_features`.")
def run_encoder(input_features):
encoder_output = WhisperEncoder.forward(self, input_features, return_dict=True)
return encoder_output.last_hidden_state
def skip_encoder(input_features):
return self._encoder_last_hidden_state
self._encoder_last_hidden_state.copy_(
poptorch.cond(self._generation_step == 0, run_encoder, [input_features], skip_encoder, [input_features])[0]
)
last_hidden_state = self._encoder_last_hidden_state
if self._num_beams > 1:
            # When beam search is enabled, the encoder outputs must be expanded over beams before being
            # passed to the decoder, since this expansion would otherwise happen on the host.
last_hidden_state = last_hidden_state.repeat_interleave(
self._num_beams, dim=0, output_size=self._batch_size * self._num_beams
)
return BaseModelOutput(last_hidden_state=last_hidden_state)
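
# Rough usage sketch (illustrative; `poptorch.cond` only runs inside a compiled PopTorch program, see
# `assert_poptorch_supports_cond` in `change_encoder_class` below): the encoder body executes only on
# the first generation step and later steps reuse the buffered activations, so encoder and decoder can
# share a single Poplar executor:
#     encoder = IPUWhisperConditionalEncoder.from_model(model.get_encoder(), batch_size=1, num_beams=4)
#     encoder(input_features)  # step 0: encodes and fills `_encoder_last_hidden_state`
#     encoder(input_features)  # later steps: returns the buffered hidden state (expanded over beams)
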
@supports_kv_cache
@register(WhisperForConditionalGeneration)
class PipelinedWhisperForConditionalGeneration(WhisperForConditionalGeneration, PipelineMixin, IPUGenerationMixin):
def change_encoder_layer_class(self, restore: bool):
"""Changes the encoder layer class to avoid the dynamic 'if'
Args:
restore: whether to restore the encoder layers to their original version or not.
"""
for layer in self.model.encoder.layers:
layer.__class__ = WhisperEncoderLayer if restore else _WhisperEncoderLayerClamp
def change_encoder_class(self, restore: bool, **kwargs):
"""Changes the encoder class to run the encoder under a `poptorch.cond` op.
Args:
restore: whether to restore the encoder to its original version or not.
"""
batch_size = kwargs.get("batch_size", 1)
num_beams = kwargs.get("num_beams", 1)
encoder = self.model.get_encoder()
if restore:
if isinstance(encoder, IPUWhisperConditionalEncoder):
self.model.encoder = encoder.to_model()
else:
if self.ipu_config.inference_ipus_per_replica > 1:
raise ValueError(
f"`{self.ipu_config.inference_ipus_per_replica=}` should be 1 when placing encoder and decoder on the same IPU."
)
assert_poptorch_supports_cond(
context="Whisper encoder is being conditionally run on the same IPU as the decoder since `use_cond_encoder=True`."
)
self.model.encoder = IPUWhisperConditionalEncoder.from_model(encoder, batch_size, num_beams)
def change_decoder_class(self, restore: bool):
"""Changes the decoder class to update the _prepare_decoder_attention_mask method.
Args:
restore: whether to restore the decoder to its original version or not.
"""
self.model.decoder.__class__ = WhisperDecoder if restore else _WhisperDecoderWithCustomMakeCausalAndExpandMask
def change_decoder_positional_embedding(self, restore: bool):
"""Changes the decoder positional embedding to support an optional static KV cache.
Args:
            restore: whether to restore the decoder positional embedding to its original version or not.
"""
position_embedding = self.model.decoder.embed_positions
self.model.decoder.embed_positions = (
position_embedding.to_model() if restore else IPUWhisperPositionalEmbedding.from_model(position_embedding)
)
def change_attention_class(self, restore=False, **kwargs):
"""Change the attention layers to support a KV cache.
Args:
restore: whether to restore the attention layers to their original version or not.
"""
batch_size = kwargs.get("batch_size", 1)
num_beams = kwargs.get("num_beams", 1)
use_cache = kwargs.get("use_cache", False)
max_length = kwargs.get("max_length", 448)
use_cross_cache = kwargs.get("use_cross_cache", False)
encoder_max_length = kwargs.get("encoder_max_length", 1500)
batch_serialization_factor = kwargs.get("batch_serialization_factor", 1)
sequence_serialization_factor = kwargs.get("sequence_serialization_factor", 1)
for encoder_layer in self.model.encoder.layers:
if restore:
encoder_layer.self_attn = encoder_layer.self_attn.to_model(WhisperAttention)
continue
encoder_layer.self_attn = IPUWhisperAttention.from_model(
encoder_layer.self_attn,
use_cache=False,
batch_serialization_factor=batch_serialization_factor,
sequence_serialization_factor=sequence_serialization_factor,
)
for decoder_layer in self.model.decoder.layers:
if restore:
decoder_layer.self_attn = decoder_layer.self_attn.to_model(WhisperAttention)
decoder_layer.encoder_attn = decoder_layer.encoder_attn.to_model(WhisperAttention)
continue
decoder_layer.self_attn = IPUWhisperAttention.from_model(
decoder_layer.self_attn,
use_cache=use_cache,
use_cross_cache=False,
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
dtype=decoder_layer.self_attn.k_proj.weight.dtype,
)
decoder_layer.encoder_attn = IPUWhisperAttention.from_model(
decoder_layer.encoder_attn,
use_cache=False,
use_cross_cache=use_cross_cache,
batch_size=batch_size,
encoder_max_length=encoder_max_length,
num_beams=num_beams,
dtype=decoder_layer.encoder_attn.k_proj.weight.dtype,
)
    def change_lm_head(self, restore: bool, use_cache: Optional[bool] = None):
# Maybe use _IndexedInputLinear
self.change_lm_head_to_indexed_input_linear(restore or use_cache)
# Maybe use SerializedLinear
if restore:
lm_head = self.get_output_embeddings()
if isinstance(lm_head, SerializedLinear):
self.set_output_embeddings(lm_head.to_model())
self.tie_weights()
else:
projection_serialization_factor = max(
self.ipu_config._projection_serialization_factor or 1,
sum(self.ipu_config._serialized_projection_splits_per_ipu or [1]),
)
if projection_serialization_factor > 1:
self.set_output_embeddings(
SerializedLinear.from_model(self.get_output_embeddings(), projection_serialization_factor)
)
self.tie_weights()
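
    # For illustration: with `_projection_serialization_factor = None` and
    # `_serialized_projection_splits_per_ipu = [0, 4]`, the factor above resolves to max(1, 4) = 4, so the
    # lm head becomes a SerializedLinear split into 4 matmuls and the embedding weights are re-tied.
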
def quantize_linear_layers(self, restore: bool, num_groups: int = 16):
if not restore:
from ...quantization.group_quantize import GroupQuantLinear
logger.info("Group quantizing linear layers")
for module in self.model.encoder.layers:
module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)
for module in self.model.decoder.layers:
module.self_attn.q_proj = GroupQuantLinear.from_model(module.self_attn.q_proj, num_groups)
module.self_attn.k_proj = GroupQuantLinear.from_model(module.self_attn.k_proj, num_groups)
module.self_attn.v_proj = GroupQuantLinear.from_model(module.self_attn.v_proj, num_groups)
module.self_attn.out_proj = GroupQuantLinear.from_model(module.self_attn.out_proj, num_groups)
module.encoder_attn.q_proj = GroupQuantLinear.from_model(module.encoder_attn.q_proj, num_groups)
module.encoder_attn.k_proj = GroupQuantLinear.from_model(module.encoder_attn.k_proj, num_groups)
module.encoder_attn.v_proj = GroupQuantLinear.from_model(module.encoder_attn.v_proj, num_groups)
module.encoder_attn.out_proj = GroupQuantLinear.from_model(module.encoder_attn.out_proj, num_groups)
module.fc1 = GroupQuantLinear.from_model(module.fc1, num_groups)
module.fc2 = GroupQuantLinear.from_model(module.fc2, num_groups)
def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=False, **kwargs):
super().parallelize()
if use_cache:
kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
self._use_cond_encoder = kwargs.get("use_cond_encoder", False)
self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
if self._use_cond_encoder and self._use_encoder_output_buffer:
raise ValueError(
"`use_cond_encoder=True` is incompatible with `use_encoder_output_buffer=True`, only set one to True."
)
self._use_group_quantized_linears = kwargs.get("use_group_quantized_linears", False)
self.change_encoder_layer_class(restore=False)
self.change_decoder_class(restore=False)
self.change_decoder_positional_embedding(restore=False)
self.change_attention_class(
restore=False,
use_cache=use_cache and for_generation,
use_cross_cache=use_cross_cache and for_generation,
**kwargs,
)
self.change_lm_head(restore=False, use_cache=use_cache or not for_generation)
self.change_encoder_class(restore=not self._use_cond_encoder, **kwargs)
self.quantize_linear_layers(restore=not self._use_group_quantized_linears, num_groups=16)
self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))
logger.info("---------- Device Allocation -----------")
logger.info("conv1, conv2, embed_positions --> IPU 0")
self.model.encoder.conv1 = poptorch.BeginBlock(self.model.encoder.conv1, "Conv1", ipu_id=0)
self.model.encoder.conv2 = poptorch.BeginBlock(self.model.encoder.conv2, "Conv2", ipu_id=0)
self.model.encoder.embed_positions = poptorch.BeginBlock(
self.model.encoder.embed_positions, "Embed Positions", ipu_id=0
)
num_encoder_layers = len(self.model.encoder.layers)
num_decoder_layers = len(self.model.decoder.layers)
if for_generation and not self._use_cond_encoder:
# If running for text generation (and the encoder and decoder are run as separate Poplar executors)
# we split the IPU config into two configs.
ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
else:
number_of_layers = num_encoder_layers + num_decoder_layers
layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
encoder_layer_ipu = layer_ipu[:num_encoder_layers]
decoder_layer_ipu = layer_ipu[num_encoder_layers:]
for index, (layer, ipu) in enumerate(zip(self.model.encoder.layers, encoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
        # Place the final encoder layer norm (model.encoder.layer_norm) on the same IPU as the last encoder layer.
self.model.encoder.layer_norm = poptorch.BeginBlock(
self.model.encoder.layer_norm, "Encoder Layer Norm", ipu_id=ipu
)
logger.info(f"Encoder LN --> IPU {ipu}")
decoder_embedding_ipu = decoder_layer_ipu[0]
if (serialized_projection_splits_per_ipu := self.ipu_config._serialized_projection_splits_per_ipu) is not None:
serialized_projection_ipus = [i for i, x in enumerate(serialized_projection_splits_per_ipu) if x]
if len(serialized_projection_ipus) > 1:
# This is because we are using SerializedLinear. All splits of a SerializedLinear layer must be on the
# same IPU. We are using SerializedLinear instead of SplitLinear because we must tie the weights, which
# cannot be done when using SplitLinear.
raise ValueError(
"`serialized_projection_splits_per_ipu` must only have 1 non-zero element for Whisper."
)
decoder_embedding_ipu = serialized_projection_ipus[0]
self.model.decoder.embed_tokens = poptorch.BeginBlock(
self.model.decoder.embed_tokens, "Decoder Embedding", ipu_id=decoder_embedding_ipu
)
logger.info(f"Decoder Embedding --> IPU {decoder_embedding_ipu}")
prev_ipu = decoder_layer_ipu[0]
for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
if ipu != prev_ipu:
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
prev_ipu = ipu
logger.info(f"Decoder {index:<2} --> IPU {ipu}")
self.model.decoder.layer_norm = poptorch.BeginBlock(
self.model.decoder.layer_norm, "Decoder Layer Norm", ipu_id=ipu
)
logger.info(f"Head --> IPU {decoder_embedding_ipu}")
logger.info("---------------------------------------")
self.proj_out = poptorch.BeginBlock(self.proj_out, "Output Projection", ipu_id=decoder_embedding_ipu)
return self
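
    # Rough usage sketch (illustrative; how the model and `ipu_config` are obtained is outside this file):
    #     pipelined = PipelinedWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    #     pipelined.ipu_config = ipu_config  # an optimum.graphcore IPUConfig (assumed to be set beforehand)
    #     pipelined = pipelined.parallelize(
    #         for_generation=True, use_cache=True, use_cond_encoder=True, batch_size=1, num_beams=1
    #     )
    # `deparallelize` below undoes every class/layer swap so the model can be saved or moved back to CPU.
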
def deparallelize(self):
super().deparallelize()
self.change_encoder_layer_class(restore=True)
self.change_decoder_class(restore=True)
self.change_decoder_positional_embedding(restore=True)
self.change_attention_class(restore=True)
self.change_lm_head(restore=True)
self.change_encoder_class(restore=True)
self.set_on_device_generation_steps(0)
return self
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
**kwargs,
):
# We don't use `past_key_values` for KV caching, and rely on `use_cache` instead.
beam_idx = None
if use_cache:
decoder_input_ids = decoder_input_ids[:, -1:]
beam_idx = kwargs.get("beam_idx", torch.arange(decoder_input_ids.shape[0], dtype=torch.long))
ret = {
"past_key_values": None,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": None,
"beam_idx": beam_idx,
}
if self.cond_encoder_enabled:
input_features = kwargs.get("input_features", None)
if input_features is None:
raise ValueError("Missing `input_features` with `use_cond_encoder=True`.")
ret["input_features"] = input_features
else:
ret["encoder_outputs"] = encoder_outputs
return ret
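
    # For illustration: with `use_cache=True` and batch_size * num_beams = 4, the returned inputs reduce to
    #     {"past_key_values": None, "decoder_input_ids": <shape (4, 1), last token only>,
    #      "decoder_attention_mask": None, "beam_idx": tensor([0, 1, 2, 3])}
    # plus either `input_features` (conditional encoder) or `encoder_outputs`; the KV cache itself and the
    # attention mask live in on-device buffers rather than in `past_key_values`.
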
# TODO: consider making such output subsetting a decorator
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
# Duplicate this portion of upstream logic so we can intercept the call to `shift_tokens_right`.
if labels is not None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
output = super().forward(
input_features=input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
# Minimize IO and only return loss when training.
return Seq2SeqLMOutput(
loss=output.loss,
logits=None if self.training else output.logits,
encoder_last_hidden_state=None if self.training else output.encoder_last_hidden_state, # for tests to pass
past_key_values=None if self.training else output.past_key_values, # for tests to pass
)