optimum/habana/transformers/models/speecht5/modeling_speecht5.py (384 lines of code) (raw):

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet, SpeechT5PreTrainedModel
from transformers.utils import logging

from ...modeling_attn_mask_utils import (
    _gaudi_prepare_4d_causal_attention_mask,
)


logger = logging.get_logger(__name__)


def gaudi_SpeechT5Attention_forward(
    self,
    hidden_states: torch.Tensor,
    key_value_states: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    position_bias: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Copied from SpeechT5Attention.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
    The only differences are:
    - add new args token_idx

    Args:
        hidden_states: Query input; unpacked as `(bsz, tgt_len, _)`.
        key_value_states: When not `None`, the layer runs as cross-attention and keys/values are projected
            from this tensor (or taken from `past_key_value` if that is also given).
        past_key_value: Cached `(key, value)` pair. For self-attention with `token_idx` set, the cached
            tensors are updated IN PLACE at position `token_idx - 1` along dim 2; without `token_idx` the new
            key/value are concatenated to the cache along dim 2.
        attention_mask: Additive mask; must have size `(bsz, 1, tgt_len, src_len)`.
        layer_head_mask: Per-head multiplier; must have size `(num_heads,)`.
        position_bias: Optional relative position input, combined with the queries and added to the scores.
        output_attentions: Whether to also return the attention weights.
        token_idx: Index tensor selecting the cache slot to overwrite (slot `token_idx - 1`).

    Returns:
        `(attn_output, attn_weights_reshaped, past_key_value)`. `attn_weights_reshaped` is `None` unless
        `output_attentions` is set. `past_key_value` is replaced by the current `(key_states, value_states)`
        only when `self.is_decoder` is true; otherwise the argument is returned unchanged.

    Raises:
        ValueError: If the attention weights, attention mask, head mask or attention output do not have the
            expected size.
    """

    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None

    bsz, tgt_len, _ = hidden_states.size()

    # get query proj
    query_states = self.q_proj(hidden_states) * self.scaling
    # get key, value proj
    if is_cross_attention and past_key_value is not None:
        # reuse k,v, cross_attentions
        key_states = past_key_value[0]
        value_states = past_key_value[1]
    elif is_cross_attention:
        # cross_attentions
        key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
        value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
    elif past_key_value is not None:
        # reuse k, v, self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        if token_idx is not None:
            # Static-shape cache: overwrite slot `token_idx - 1` in place instead of growing the cache.
            past_key_value[0].index_copy_(2, token_idx - 1, key_states)
            past_key_value[1].index_copy_(2, token_idx - 1, value_states)
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        else:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
    else:
        # self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

    if self.is_decoder:
        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
        # Further calls to cross_attention layer can then reuse all cross-attention
        # key/value_states (first "if" case)
        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
        # all previous decoder key/value_states. Further calls to uni-directional self-attention
        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
        # if encoder bi-directional self-attention `past_key_value` is always `None`
        past_key_value = (key_states, value_states)

    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
    query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
    key_states = key_states.view(*proj_shape)
    value_states = value_states.view(*proj_shape)

    src_len = key_states.size(1)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
            f" {attn_weights.size()}"
        )

    # relative attention bias
    if position_bias is not None:
        reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)
        rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
        rel_pos_bias = rel_pos_bias.transpose(0, 1).view(
            bsz * self.num_heads, position_bias.size(0), position_bias.size(1)
        )
        attn_weights += rel_pos_bias

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if layer_head_mask is not None:
        if layer_head_mask.size() != (self.num_heads,):
            raise ValueError(
                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
        attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if output_attentions:
        # this operation is a bit awkward, but it's required to
        # make sure that attn_weights keeps its gradient.
        # In order to do so, attn_weights have to be reshaped
        # twice and have to be reused in the following
        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    attn_output = torch.bmm(attn_probs, value_states)

    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        # Report the same 3-D shape that the check above compares against.
        raise ValueError(
            f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
    attn_output = attn_output.transpose(1, 2)

    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
    # partitioned across GPUs when using tensor-parallelism.
    attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights_reshaped, past_key_value


def gaudi_SpeechT5DecoderLayer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = True,
    token_idx: Optional[torch.Tensor] = None,
):
    """
    Copied from SpeechT5DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
    The only differences are:
    - add token_idx in self-attention

    Args:
        hidden_states: Input to the layer.
        attention_mask: Mask forwarded to the self-attention module.
        encoder_hidden_states: When not `None`, the cross-attention block is run against this tensor.
        encoder_attention_mask: Mask forwarded to the cross-attention module.
        layer_head_mask: Head mask forwarded to the self-attention module.
        cross_attn_layer_head_mask: Head mask forwarded to the cross-attention module.
        past_key_value: Cached tensors; the first two entries are used for self-attention and the last two
            for cross-attention.
        output_attentions: Whether to append the self- and cross-attention weights to the output.
        use_cache: Whether to append the present key/value tuple to the output.
        token_idx: Forwarded to the self-attention module only.

    Returns:
        A tuple starting with the layer output, followed by `(self_attn_weights, cross_attn_weights)` when
        `output_attentions` is truthy, followed by the present key/value tuple when `use_cache` is truthy.
    """
    residual = hidden_states

    # Self Attention
    # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
    self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
    # add present self-attn cache to positions 1,2 of present_key_value tuple
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        past_key_value=self_attn_past_key_value,
        attention_mask=attention_mask,
        layer_head_mask=layer_head_mask,
        output_attentions=output_attentions,
        token_idx=token_idx,
    )
    hidden_states = self.dropout(hidden_states)
    hidden_states = residual + hidden_states
    hidden_states = self.self_attn_layer_norm(hidden_states)

    # Cross-Attention Block
    cross_attn_present_key_value = None
    cross_attn_weights = None
    if encoder_hidden_states is not None:
        residual = hidden_states

        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
            hidden_states=hidden_states,
            key_value_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            layer_head_mask=cross_attn_layer_head_mask,
            past_key_value=cross_attn_past_key_value,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # add cross-attn to positions 3,4 of present_key_value tuple
        present_key_value = present_key_value + cross_attn_present_key_value

    # Fully Connected
    hidden_states = hidden_states + self.feed_forward(hidden_states)
    hidden_states = self.final_layer_norm(hidden_states)

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights, cross_attn_weights)

    if use_cache:
        outputs += (present_key_value,)

    return outputs


def gaudi_SpeechT5Decoder_forward(
    self,
    hidden_states: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
    """
    Copied from SpeechT5Decoder.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
    The only differences are:
    - add token_idx args
    - use _gaudi_prepare_4d_causal_attention_mask

    Args:
        hidden_states: Decoder input features; all but the last dimension are used as the input shape.
        attention_mask: Decoder mask, expanded through `_gaudi_prepare_4d_causal_attention_mask`.
        encoder_hidden_states: Encoder output used for cross-attention.
        encoder_attention_mask: Encoder mask; expanded to 4-D when `encoder_hidden_states` is also given.
        head_mask: Per-layer self-attention head masks; first dimension must equal the number of layers.
        cross_attn_head_mask: Per-layer cross-attention head masks; same layer-count requirement.
        past_key_values: Per-layer caches; the past length is read from `past_key_values[0][0].shape[2]`.
        use_cache: Whether to return the per-layer caches. Defaults to `self.config.use_cache`.
        output_attentions: Defaults to `self.config.output_attentions`.
        output_hidden_states: Defaults to `self.config.output_hidden_states`.
        return_dict: Defaults to `self.config.use_return_dict`.
        token_idx: Forwarded to every decoder layer on the non-checkpointed path.

    Returns:
        A `BaseModelOutputWithPastAndCrossAttentions`, or a tuple of its non-`None` fields when `return_dict`
        is falsy.

    Raises:
        ValueError: If `head_mask` or `cross_attn_head_mask` is given for a different number of layers than
            the decoder has.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    input_shape = hidden_states.size()[:-1]

    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    attention_mask = _gaudi_prepare_4d_causal_attention_mask(
        attention_mask, input_shape, hidden_states, past_key_values_length
    )

    # expand encoder attention mask
    if encoder_hidden_states is not None and encoder_attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        encoder_attention_mask = _prepare_4d_attention_mask(
            encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
        )

    synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
    next_decoder_cache = () if use_cache else None

    # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None:
            if attn_mask.size()[0] != (len(self.layers)):
                # Report the size of the mask being checked. Reading `head_mask` here would fail with an
                # AttributeError when only `cross_attn_head_mask` is supplied.
                raise ValueError(
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        skip_the_layer = False
        if self.training:
            dropout_probability = torch.rand([])
            skip_the_layer = dropout_probability < self.layerdrop
        if skip_the_layer and not synced_gpus:
            continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            # Positional arguments follow the decoder layer's parameter order; the cache slot is `None`.
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                head_mask[idx] if head_mask is not None else None,
                cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                None,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                token_idx=token_idx,
            )
        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache comes after the two attention entries when those are returned.
            next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

        if output_attentions:
            all_self_attentions = all_self_attentions + (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions]
            if v is not None
        )

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )


def gaudi_generate_speech(
    model: SpeechT5PreTrainedModel,
    input_values: torch.FloatTensor,
    speaker_embeddings: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    threshold: float = 0.5,
    minlenratio: float = 0.0,
    maxlenratio: float = 20.0,
    vocoder: Optional[nn.Module] = None,
    output_cross_attentions: bool = False,
    return_output_lengths: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
    """
    Copied from _generate_speech: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/speecht5/modeling_speecht5.py
    The only differences are:
    - add hpu graph wrap
    - add static shape support in kv-cache in _generate_speech
    - disable speech_decoder_prenet_dropout to avoid variable output length
      (NOTE(review): nothing in this function disables that dropout; presumably it is done where the prenet
      is patched — confirm)

    Args:
        model: SpeechT5 model whose encoder, decoder prenet and wrapped decoder are wrapped in HPU graphs on
            first use (skipped when the module already exposes `clear_cache`).
        input_values: Encoder input.
        speaker_embeddings: Required; passed to the decoder prenet at every step.
        attention_mask: Encoder attention mask. When `None` it is derived from
            `input_values != model.config.pad_token_id`.
        threshold: A sample stops once the summed stop probability reaches this value.
        minlenratio: Scales the encoder length into the minimum number of decoding steps.
        maxlenratio: Scales the encoder length into the maximum number of decoding steps, which is also the
            length of the pre-allocated output buffer and decoder attention mask.
        vocoder: Optional module applied to the spectrogram(s) to produce waveforms.
        output_cross_attentions: Whether to also return the collected cross-attention weights.
        return_output_lengths: Whether to return padded batch outputs together with per-sample lengths.

    Returns:
        The spectrogram (or waveform when `vocoder` is given). With `return_output_lengths` a tuple of
        `(outputs, lengths)`. When `output_cross_attentions` is set the cross-attentions are appended as an
        extra tuple element.

    Raises:
        ValueError: If `speaker_embeddings` is `None`.
    """
    if speaker_embeddings is None:
        raise ValueError(
            """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
                    the code snippet provided in this link:
                    https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
                    """
        )

    from habana_frameworks.torch.hpu import wrap_in_hpu_graph

    # Wrap each sub-module once; a module that already has `clear_cache` is treated as wrapped.
    if not hasattr(model.speecht5.encoder, "clear_cache"):
        model.speecht5.encoder = wrap_in_hpu_graph(model.speecht5.encoder)
    if not hasattr(model.speecht5.decoder.wrapped_decoder, "clear_cache"):
        model.speecht5.decoder.wrapped_decoder = wrap_in_hpu_graph(model.speecht5.decoder.wrapped_decoder)
    if not hasattr(model.speecht5.decoder.prenet, "clear_cache"):
        model.speecht5.decoder.prenet = wrap_in_hpu_graph(model.speecht5.decoder.prenet)

    if attention_mask is None:
        encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
    else:
        encoder_attention_mask = attention_mask

    bsz = input_values.size(0)
    encoder_out = model.speecht5.encoder(
        input_values=input_values,
        attention_mask=encoder_attention_mask,
        return_dict=True,
    )

    encoder_last_hidden_state = encoder_out.last_hidden_state

    # downsample encoder attention mask
    if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
        encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
            encoder_out[0].shape[1], encoder_attention_mask
        )

    maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor)
    minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor)

    # Start the output sequence with a mel spectrum that is all zeros.
    output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins)
    # Pre-allocate the remaining `maxlen - 1` frames so the buffer keeps a fixed length of `maxlen`.
    output_sequence = torch.nn.functional.pad(output_sequence, (0, 0, 0, maxlen - 1), value=model.config.pad_token_id)
    spectrogram = []
    cross_attentions = []
    past_key_values = None
    idx = 0
    result_spectrogram = {}
    # `token_idx` is the buffer slot the next predicted frame is written to; `idx` mirrors it on the host.
    token_idx = torch.tensor(1, device=output_sequence.device)
    attention_mask = torch.zeros((bsz, maxlen), dtype=torch.long, device=output_sequence.device)
    while True:
        idx += 1
        # Unmask the position that is decoded in this step.
        attention_mask.index_fill_(1, token_idx - 1, 1)
        # Run the decoder prenet on the entire output sequence.
        decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
        # Run the decoder layers on the last element of the prenet output.
        decoder_out = model.speecht5.decoder.wrapped_decoder(
            hidden_states=decoder_hidden_states
            if past_key_values is None
            else torch.index_select(decoder_hidden_states, 1, token_idx - 1),
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=output_cross_attentions,
            return_dict=True,
            token_idx=token_idx,
        )

        if output_cross_attentions:
            cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))

        last_decoder_output = decoder_out.last_hidden_state[:, 0:1, :].squeeze(1)
        past_key_values = decoder_out.past_key_values
        # Predict the new mel spectrum for this step in the sequence.
        spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
        spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
        spectrogram.append(spectrum)
        # Feed the last predicted frame back as the next decoder input. The buffer only has slots
        # 0..maxlen-1, so on the step where `idx == maxlen` the target slot would be one past the end;
        # that frame is never consumed when decoding stops at `maxlen`, so the write is skipped.
        if idx < maxlen:
            output_sequence.index_copy_(1, token_idx, spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins))
        # Predict the probability that this is the stop token.
        prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
        token_idx.add_(1)
        # Finished when stop token or maximum length is reached.
        if idx < minlen:
            continue
        else:
            # If the generation loop is less than maximum length time, check the ones in the batch that have met
            # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch.
            if idx < maxlen:
                meet_thresholds = torch.sum(prob, dim=-1) >= threshold
                meet_indexes = torch.where(meet_thresholds)[0].tolist()
            else:
                meet_indexes = range(len(prob))
            meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]
            if len(meet_indexes) > 0:
                spectrograms = torch.stack(spectrogram)
                spectrograms = spectrograms.transpose(0, 1).flatten(1, 2)
                spectrograms = model.speech_decoder_postnet.postnet(spectrograms)
                for meet_index in meet_indexes:
                    result_spectrogram[meet_index] = spectrograms[meet_index]
            if len(result_spectrogram) >= bsz:
                break

    spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]
    if not return_output_lengths:
        spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
        if vocoder is not None:
            outputs = vocoder(spectrogram)
        else:
            outputs = spectrogram
        if output_cross_attentions:
            cross_attentions = torch.cat(cross_attentions, dim=2)
            if bsz > 1:
                cross_attentions = cross_attentions.view(
                    bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
                )
            outputs = (outputs, cross_attentions)
    else:
        # batched return values should also include the spectrogram/waveform lengths
        spectrogram_lengths = [spectrograms[i].size(0) for i in range(bsz)]
        spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
        if vocoder is None:
            outputs = (spectrograms, spectrogram_lengths)
        else:
            waveforms = vocoder(spectrograms)
            # Scale each spectrogram length by the integer number of waveform samples per padded frame.
            waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
            outputs = (waveforms, waveform_lengths)
        if output_cross_attentions:
            cross_attentions = torch.cat(cross_attentions, dim=2)
            cross_attentions = cross_attentions.view(
                bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
            )
            outputs = (*outputs, cross_attentions)
    return outputs