# optimum/habana/transformers/models/speecht5/modeling_speecht5.py
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet, SpeechT5PreTrainedModel
from transformers.utils import logging
from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)

logger = logging.get_logger(__name__)


def gaudi_SpeechT5Attention_forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
output_attentions: bool = False,
token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from SpeechT5Attention.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
The only difference is:
- add the new `token_idx` argument
"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if token_idx is not None:
past_key_value[0].index_copy_(2, token_idx - 1, key_states)
past_key_value[1].index_copy_(2, token_idx - 1, value_states)
key_states = past_key_value[0]
value_states = past_key_value[1]
else:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# relative attention bias
if position_bias is not None:
reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)
rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
rel_pos_bias = rel_pos_bias.transpose(0, 1).view(
bsz * self.num_heads, position_bias.size(0), position_bias.size(1)
)
attn_weights += rel_pos_bias
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
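

# The `token_idx` path above updates a pre-allocated ("static shape") KV cache in place
# instead of concatenating new states, so tensor shapes stay constant across decoding
# steps, which is the usual requirement for avoiding recompilation on HPU. The helper
# below is only a minimal illustrative sketch of that update; its name and the shapes in
# the comments are assumptions for demonstration and it is not used by the model code.
def _illustrative_static_kv_cache_update(
    cache: torch.Tensor,  # pre-allocated cache, e.g. [bsz, num_heads, max_len, head_dim]
    new_states: torch.Tensor,  # current step's key or value states, e.g. [bsz, num_heads, 1, head_dim]
    token_idx: torch.Tensor,  # 1-based index of the current decoding step
) -> torch.Tensor:
    # Overwrite the slot for the current step along the sequence dimension (dim 2),
    # mirroring `past_key_value[0].index_copy_(2, token_idx - 1, key_states)` above.
    cache.index_copy_(2, token_idx - 1, new_states)
    return cache

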
def gaudi_SpeechT5DecoderLayer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
token_idx: Optional[torch.Tensor] = None,
):
"""
Copied from SpeechT5DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
The only difference is:
- pass `token_idx` to the self-attention call
"""
residual = hidden_states
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
token_idx=token_idx,
)
hidden_states = self.dropout(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = self.dropout(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
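

# In the decoder layer above, each cached entry is a 4-tuple: the self-attention key/value
# states come first and the cross-attention key/value states last (`past_key_value[:2]`
# and `past_key_value[-2:]`). The sketch below only illustrates that split; the function
# name is an assumption for demonstration and it is not part of the model code.
def _illustrative_split_layer_cache(
    past_key_value: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
    # Returns (self_attn_key, self_attn_value), (cross_attn_key, cross_attn_value).
    return past_key_value[:2], past_key_value[-2:]

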
def gaudi_SpeechT5Decoder_forward(
self,
hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
"""
Copied from SpeechT5Decoder.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
The only differences are:
- add the `token_idx` argument
- use _gaudi_prepare_4d_causal_attention_mask
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_shape = hidden_states.size()[:-1]
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
attention_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
)
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {attn_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
skip_the_layer = False
if self.training:
dropout_probability = torch.rand([])
skip_the_layer = dropout_probability < self.layerdrop
if skip_the_layer and not synced_gpus:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
token_idx=token_idx,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
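

# The decoder above turns 2D masks of shape [bsz, seq_len] into additive 4D masks of shape
# [bsz, 1, tgt_len, src_len] (0.0 where attention is allowed, a large negative bias where it
# is masked) before they are added to the attention scores. The function below is only an
# illustrative sketch of that expansion for a padding mask; it is not the actual
# `_prepare_4d_attention_mask` or `_gaudi_prepare_4d_causal_attention_mask` helper.
def _illustrative_expand_padding_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # `mask` is [bsz, src_len] with 1 for real tokens and 0 for padding.
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Masked positions get a large negative bias so softmax assigns them ~0 probability.
    return (1.0 - expanded) * torch.finfo(dtype).min

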
def gaudi_generate_speech(
model: SpeechT5PreTrainedModel,
input_values: torch.FloatTensor,
speaker_embeddings: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
return_output_lengths: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
"""
Copied from _generate_speech: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
The only differences are:
- wrap the encoder, decoder and prenet in HPU graphs
- add static-shape KV cache support in _generate_speech
- disable speech_decoder_prenet_dropout to avoid variable output length
"""
if speaker_embeddings is None:
raise ValueError(
"""`speaker_embeddings` must be specified. For example, you can obtain speaker embeddings by following
the code snippet provided in this link:
https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
"""
)
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
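# Wrap the encoder, the wrapped decoder and the decoder prenet in HPU graphs so their work is
# captured once and replayed on subsequent calls; the `clear_cache` attribute (present only
# after wrapping) is used as a marker to avoid wrapping a module twice.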
if not hasattr(model.speecht5.encoder, "clear_cache"):
model.speecht5.encoder = wrap_in_hpu_graph(model.speecht5.encoder)
if not hasattr(model.speecht5.decoder.wrapped_decoder, "clear_cache"):
model.speecht5.decoder.wrapped_decoder = wrap_in_hpu_graph(model.speecht5.decoder.wrapped_decoder)
if not hasattr(model.speecht5.decoder.prenet, "clear_cache"):
model.speecht5.decoder.prenet = wrap_in_hpu_graph(model.speecht5.decoder.prenet)
if attention_mask is None:
encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
else:
encoder_attention_mask = attention_mask
bsz = input_values.size(0)
encoder_out = model.speecht5.encoder(
input_values=input_values,
attention_mask=encoder_attention_mask,
return_dict=True,
)
encoder_last_hidden_state = encoder_out.last_hidden_state
# downsample encoder attention mask
if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
encoder_out[0].shape[1], encoder_attention_mask
)
maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor)
minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor)
# Start the output sequence with a mel spectrum that is all zeros.
output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins)
output_sequence = torch.nn.functional.pad(output_sequence, (0, 0, 0, maxlen - 1), value=model.config.pad_token_id)
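# Static-shape trick: the output sequence is pre-allocated to `maxlen` up front, so each
# decoding step writes its new frame in place (see `index_copy_` below) instead of growing
# the tensor, keeping shapes constant for the HPU graphs.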
spectrogram = []
cross_attentions = []
past_key_values = None
idx = 0
result_spectrogram = {}
token_idx = torch.tensor(1, device=output_sequence.device)
attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device)
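# `token_idx` tracks the current decoding position (1-based). The attention mask is likewise
# pre-allocated to `maxlen` and positions are enabled one step at a time with `index_fill_`,
# so its shape also stays constant across decoding steps.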
while True:
idx += 1
attention_mask.index_fill_(1, token_idx - 1, 1)
# Run the decoder prenet on the entire output sequence.
decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
# Run the decoder layers on the last element of the prenet output.
decoder_out = model.speecht5.decoder.wrapped_decoder(
hidden_states=decoder_hidden_states
if past_key_values is None
else torch.index_select(decoder_hidden_states, 1, token_idx - 1),
attention_mask=attention_mask,
encoder_hidden_states=encoder_last_hidden_state,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=True,
output_attentions=output_cross_attentions,
return_dict=True,
token_idx=token_idx,
)
if output_cross_attentions:
cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))
last_decoder_output = decoder_out.last_hidden_state[:, 0:1, :].squeeze(1)
past_key_values = decoder_out.past_key_values
# Predict the new mel spectrum for this step in the sequence.
spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
spectrogram.append(spectrum)
output_sequence.index_copy_(1, token_idx, spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins))
# Predict the probability that this is the stop token.
prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
token_idx.add_(1)
# Finished when stop token or maximum length is reached.
if idx < minlen:
continue
else:
# If the maximum length has not been reached yet, check which sequences in the batch have met
# the stop-probability threshold. Otherwise, assume every remaining sequence has met it and fill in its spectrogram.
if idx < maxlen:
meet_thresholds = torch.sum(prob, dim=-1) >= threshold
meet_indexes = torch.where(meet_thresholds)[0].tolist()
else:
meet_indexes = range(len(prob))
meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]
if len(meet_indexes) > 0:
spectrograms = torch.stack(spectrogram)
spectrograms = spectrograms.transpose(0, 1).flatten(1, 2)
spectrograms = model.speech_decoder_postnet.postnet(spectrograms)
for meet_index in meet_indexes:
result_spectrogram[meet_index] = spectrograms[meet_index]
if len(result_spectrogram) >= bsz:
break
spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]
if not return_output_lengths:
spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
if vocoder is not None:
outputs = vocoder(spectrogram)
else:
outputs = spectrogram
if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
if bsz > 1:
cross_attentions = cross_attentions.view(
bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
)
outputs = (outputs, cross_attentions)
else:
# batched return values should also include the spectrogram/waveform lengths
spectrogram_lengths = []
for i in range(bsz):
spectrogram_lengths.append(spectrograms[i].size(0))
if vocoder is None:
spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
outputs = (spectrograms, spectrogram_lengths)
else:
waveforms = []
spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
waveforms = vocoder(spectrograms)
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
outputs = (waveforms, waveform_lengths)
if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
cross_attentions = cross_attentions.view(
bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
)
outputs = (*outputs, cross_attentions)
return outputs
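

# Example usage sketch (kept as a comment, not executed): calling `gaudi_generate_speech`
# directly with a SpeechT5 text-to-speech model on HPU. The checkpoint names and the
# speaker-embedding source are assumptions taken from the public SpeechT5 examples; adapt
# them to your setup.
#
#     from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
#
#     processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
#     model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to("hpu")
#     vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to("hpu")
#     inputs = processor(text="Hello, world!", return_tensors="pt").to("hpu")
#     # speaker_embeddings: a [1, 512] x-vector, e.g. from the Matthijs/cmu-arctic-xvectors dataset
#     speech = gaudi_generate_speech(
#         model, inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder
#     )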