# optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py

from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from transformers.models.seamless_m4t.modeling_seamless_m4t import (
    SeamlessM4TForTextToSpeech,
    SeamlessM4TGenerationOutput,
    _compute_new_attention_mask,
    format_speech_generation_kwargs,
    shift_tokens_right,
)
from transformers.utils import logging

from ...modeling_attn_mask_utils import (
    _gaudi_prepare_4d_causal_attention_mask,
)


logger = logging.get_logger(__name__)


def gaudi_SeamlessM4TAttention_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Copied from SeamlessM4TAttention.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx args
    """
    # if encoder_hidden_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = encoder_hidden_states is not None

    bsz, tgt_len, _ = hidden_states.size()

    # get query proj
    query_states = self.q_proj(hidden_states) * self.scaling
    # get key, value proj
    # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]`
    # is checking that the `sequence_length` of the `past_key_value` is the same as
    # the provided `encoder_hidden_states` to support prefix tuning
    if (
        is_cross_attention
        and past_key_value is not None
        and past_key_value[0].shape[2] == encoder_hidden_states.shape[1]
    ):
        # reuse k,v, cross_attentions
        key_states = past_key_value[0]
        value_states = past_key_value[1]
    elif is_cross_attention:
        # cross_attentions
        key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
    elif past_key_value is not None:
        # reuse k, v, self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        if token_idx is None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            past_key_value[0].index_copy_(2, token_idx - 1, key_states)
            past_key_value[1].index_copy_(2, token_idx - 1, value_states)
            key_states = past_key_value[0]
            value_states = past_key_value[1]
    else:
        # self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

    if self.is_decoder:
        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
        # Further calls to cross_attention layer can then reuse all cross-attention
        # key/value_states (first "if" case)
        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
        # all previous decoder key/value_states.
        # Further calls to uni-directional self-attention
        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
        # if encoder bi-directional self-attention `past_key_value` is always `None`
        past_key_value = (key_states, value_states)

    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
    query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
    key_states = key_states.reshape(*proj_shape)
    value_states = value_states.reshape(*proj_shape)

    src_len = key_states.size(1)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if output_attentions:
        # this operation is a bit awkward, but it's required to
        # make sure that attn_weights keeps its gradient.
        # In order to do so, attn_weights have to be reshaped
        # twice and have to be reused in the following
        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    attn_output = torch.bmm(attn_probs, value_states)

    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
    attn_output = attn_output.transpose(1, 2)

    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
    # partitioned across GPUs when using tensor-parallelism.
    attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights_reshaped, past_key_value


def gaudi_SeamlessM4TDecoderLayer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = True,
    token_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Copied from SeamlessM4TDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx args
    """
    residual = hidden_states
    hidden_states = self.self_attn_layer_norm(hidden_states)

    # Self Attention
    # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
    self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

    # add present self-attn cache to positions 1,2 of present_key_value tuple
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        past_key_value=self_attn_past_key_value,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        token_idx=token_idx,
    )
    hidden_states = self.attn_dropout(hidden_states)
    hidden_states = residual + hidden_states

    # Cross-Attention Block
    cross_attn_present_key_value = None
    cross_attn_weights = None
    if encoder_hidden_states is not None:
        residual = hidden_states
        hidden_states = self.cross_attention_layer_norm(hidden_states)

        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None

        hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            past_key_value=cross_attn_past_key_value,
            attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.attn_dropout(hidden_states)
        hidden_states = residual + hidden_states

        # add cross-attn to positions 3,4 of present_key_value tuple
        present_key_value += cross_attn_present_key_value

    # Fully Connected
    residual = hidden_states
    hidden_states = self.ffn_layer_norm(hidden_states)

    hidden_states = self.ffn(hidden_states)
    hidden_states = self.ffn_dropout(hidden_states)

    hidden_states = residual + hidden_states

    outputs = (hidden_states, present_key_value)

    if output_attentions:
        outputs += (self_attn_weights, cross_attn_weights)

    return outputs


def gaudi_SeamlessM4TDecoder_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
    """
    Copied from SeamlessM4TDecoder.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx args
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        input = input_ids
        input_shape = input.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        input = inputs_embeds[:, :, -1]
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    # past_key_values_length
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    attention_mask = _gaudi_prepare_4d_causal_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

    # expand encoder attention mask
    if encoder_hidden_states is not None and encoder_attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        encoder_attention_mask = _prepare_4d_attention_mask(
            encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
        )

    # embed positions
    if past_key_values_length != 0 and token_idx is not None:
        past_key_values_length = token_idx - 1
    positions = self.embed_positions(input, past_key_values_length=past_key_values_length)

    hidden_states = inputs_embeds + positions.to(inputs_embeds.device)

    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.training:
            dropout_probability = torch.rand([])
            if dropout_probability < self.layerdrop:
                continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                None,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                token_idx=token_idx,
            )
        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[1],)

        if output_attentions:
            all_self_attns += (layer_outputs[2],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[3],)

    hidden_states = self.layer_norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        cross_attentions=all_cross_attentions,
    )


def gaudi_SeamlessM4TTextToUnitModel_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
    """
    Copied from SeamlessM4TTextToUnitModel.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx args
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_outputs[0],
        encoder_attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        token_idx=token_idx,
    )

    if not return_dict:
        return decoder_outputs + encoder_outputs

    return Seq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )


def gaudi_SeamlessM4TTextToUnitForConditionalGeneration_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
    """
    Copied from SeamlessM4TTextToUnitForConditionalGeneration.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx args
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if labels is not None:
        if use_cache:
            logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
        use_cache = False
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id
            )

    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        token_idx=token_idx,
    )
    lm_logits = self.lm_head(outputs[0])

    masked_lm_loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        labels = labels.to(lm_logits.device)
        masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

    if not return_dict:
        output = (lm_logits,) + outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return Seq2SeqLMOutput(
        loss=masked_lm_loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        cross_attentions=outputs.cross_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
    )


def gaudi_SeamlessM4TTextToUnitForConditionalGeneration_prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    past_key_values=None,
    attention_mask=None,
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    """
    Copied from SeamlessM4TTextToUnitForConditionalGeneration.prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx args
    """
    token_idx = kwargs.get("token_idx", None)
    # cut decoder_input_ids if past is used
    if past_key_values is not None:
        if token_idx is None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        else:
            decoder_input_ids = torch.index_select(decoder_input_ids, 1, token_idx - 1)

    return {
        "input_ids": None,  # encoder_outputs is defined. input_ids not needed
        "encoder_outputs": encoder_outputs,
        "past_key_values": past_key_values,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "use_cache": use_cache,
        "token_idx": token_idx,
        "decoder_attention_mask": kwargs.get("decoder_attention_mask", None),
    }


def gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths(self, input_lengths: Union[torch.LongTensor, int]):
    """
    Copied from SeamlessM4TCodeHifiGan._get_output_hifigan_lengths: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - fix torch.div issue
    """

    def _conv_out_length(input_length, kernel_size, stride, pad, dilation=1):
        # 1D convolutional layer output length formula taken
        # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
        return (
            torch.div(input_length.item() + 2 * pad - dilation * (kernel_size - 1) - 1, stride, rounding_mode="floor")
            + 1
        )

    def _transpose_conv_out_length(input_length, kernel_size, stride, pad, dilation=1):
        return (input_length - 1) * stride - 2 * pad + dilation * (kernel_size - 1) + 1

    # conv_pre
    input_lengths = _conv_out_length(input_lengths, 7, 1, 3)

    # upsampler
    for i, (upsample_rate, kernel_size) in enumerate(
        zip(self.config.upsample_rates, self.config.upsample_kernel_sizes)
    ):
        input_lengths = _transpose_conv_out_length(
            input_lengths, kernel_size, upsample_rate, (kernel_size - upsample_rate) // 2
        )

    # resblock
    for i in range(len(self.config.upsample_rates)):
        for kernel_size, dilation in zip(self.config.resblock_kernel_sizes, self.config.resblock_dilation_sizes):
            for dil in dilation:
                input_lengths = _conv_out_length(
                    input_lengths, kernel_size, 1, (kernel_size - 1) * dil // 2, dilation=dil
                )

            for dil in dilation:
                input_lengths = _conv_out_length(input_lengths, kernel_size, 1, (kernel_size - 1) // 2, dilation=1)

    # conv_post
    input_lengths = _conv_out_length(input_lengths, 7, 1, 3)

    return input_lengths


def gaudi_SeamlessM4TForTextToSpeech_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
    """
    Copied from SeamlessM4TForTextToSpeech.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx args
    """
    if labels is not None:
        if use_cache:
            logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
        use_cache = False
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if encoder_outputs is None:
        # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn
        logger.warning(
            "This is the same forward method as `SeamlessM4TForTextToText`."
            "It doesn't use the text-to-unit model `SeamlessM4TTextToUnitForConditionalGeneration`."
            "If you want to generate speech, use the `.generate` method."
        )
        encoder_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    encoder_attention_mask = attention_mask

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
    decoder_outputs = self.text_decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_outputs[0],
        encoder_attention_mask=encoder_attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        token_idx=token_idx,
    )

    lm_logits = self.lm_head(decoder_outputs[0])

    masked_lm_loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        labels = labels.to(lm_logits.device)
        masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

    if not return_dict:
        outputs = decoder_outputs + encoder_outputs
        output = (lm_logits,) + outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return Seq2SeqLMOutput(
        loss=masked_lm_loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )


@torch.no_grad()
def gaudi_SeamlessM4TForTextToSpeech_generate(
    self,
    input_ids: Optional[torch.Tensor] = None,
    return_intermediate_token_ids: Optional[bool] = None,
    tgt_lang: Optional[str] = None,
    spkr_id: Optional[int] = 0,
    **kwargs,
) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]:
    """
    Copied from SeamlessM4TForTextToSpeech.generate: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - delete pad id for unit_ids output
    """
    batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))

    if tgt_lang is None:
        raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
    else:
        # also accept __xxx__
        tgt_lang = tgt_lang.replace("__", "")
        for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
            lang_code_to_id = getattr(self.generation_config, key, None)
            if lang_code_to_id is None:
                raise ValueError(
                    f"""This model generation config doesn't have a `{key}` key which maps the target language
                    to the right token id. Make sure to load the right generation config."""
                )
            elif tgt_lang not in lang_code_to_id:
                raise ValueError(
                    f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in
                    {",".join(lang_code_to_id.keys())}.
                    Note that SeamlessM4T supports more languages for text translation than for speech synthesis."""
                )

    if kwargs.get("hpu_graphs", True):
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        if not hasattr(self, "clear_cache"):
            self = wrap_in_hpu_graph(self)
        if not hasattr(self.t2u_model, "clear_cache"):
            self.t2u_model = wrap_in_hpu_graph(self.t2u_model)
        if not hasattr(self.vocoder, "clear_cache"):
            self.vocoder = wrap_in_hpu_graph(self.vocoder)

    kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs)
    kwargs_text["output_hidden_states"] = True
    kwargs_text["return_dict_in_generate"] = True
    kwargs_text["output_scores"] = True

    text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
    # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
    text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
    text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)

    kwargs_text["decoder_input_ids"] = text_decoder_input_ids

    # first generation
    text_generation_output = super(SeamlessM4TForTextToSpeech, self).generate(input_ids, **kwargs_text)
    sequences = text_generation_output.sequences

    # prepare second generation
    num_return_sequences = len(sequences) // batch_size
    attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None))

    encoder_hidden_states = text_generation_output.encoder_hidden_states[-1]

    # take care of num_return_sequences
    # take most probable hidden states per batch of return_sequences
    # (batch_size*num_return_sequences, ...) -> (batch_size,...)
    if num_return_sequences > 1:
        idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1)
        idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1)
        idx_most_probable_sequences_per_batch = (
            idx_most_probable_sequences_per_batch
            + torch.arange(batch_size, device=self.device) * num_return_sequences
        )
        sequences = sequences[idx_most_probable_sequences_per_batch]

    # get decoder last hidden state - must do a pass through the text decoder
    t2u_input_embeds = self.text_decoder(
        input_ids=sequences,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=attention_mask,
    ).last_hidden_state

    pad_token_id = self.generation_config.pad_token_id

    # Compute new attention mask
    seq_lens = (sequences != pad_token_id).int().sum(1)
    t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens)
    kwargs_speech["attention_mask"] = t2u_model_attention_mask

    # Compute t2u decoder_input_ids
    t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids")
    t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang)
    t2u_decoder_input_ids = torch.tensor(
        [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device
    )
    kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids

    # second generation
    unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech)

    seq_lens = (unit_ids != self.config.t2u_pad_token_id).int().sum(1)
    unit_ids = unit_ids[:, 0:seq_lens]
    output_unit_ids = unit_ids.detach().clone()

    # get rid of t2u_decoder_input_ids
    unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :]
    # replace eos per pad
    unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id
    # offset of control symbols
    unit_ids = torch.where(unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset)
    vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
    vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)

    spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device)

    waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id)

    if return_intermediate_token_ids:
        return SeamlessM4TGenerationOutput(
            waveform=waveform,
            waveform_lengths=waveform_lengths,
            sequences=sequences,
            unit_sequences=output_unit_ids,
        )

    return waveform, waveform_lengths


def gaudi_SeamlessM4TForTextToSpeech_prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    past_key_values=None,
    attention_mask=None,
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    """
    Copied from SeamlessM4TForTextToSpeech.prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
    The only differences are:
    - add token_idx
    """
    token_idx = kwargs.get("token_idx", None)
    # cut decoder_input_ids if past is used
    if past_key_values is not None:
        if token_idx is None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        else:
            decoder_input_ids = torch.index_select(decoder_input_ids, 1, token_idx - 1)

    return {
        "input_ids": None,  # encoder_outputs is defined. input_ids not needed
        "encoder_outputs": encoder_outputs,
        "past_key_values": past_key_values,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "use_cache": use_cache,
        "token_idx": token_idx,
        "decoder_attention_mask": kwargs.get("decoder_attention_mask", None),
    }
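
# Example usage (minimal sketch, kept as a comment): this assumes the gaudi_* overrides in this
# module have been bound to the corresponding transformers SeamlessM4T classes (optimum-habana
# normally wires this up when adapting transformers for Gaudi) and that an HPU device is
# available. The checkpoint name and the inputs below are illustrative, not prescriptive.
#
#     from transformers import AutoProcessor, SeamlessM4TForTextToSpeech
#
#     processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
#     model = SeamlessM4TForTextToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium").to("hpu")
#
#     text_inputs = processor(text="Hello, my dog is cute.", src_lang="eng", return_tensors="pt").to("hpu")
#     # gaudi_SeamlessM4TForTextToSpeech_generate runs text decoding, then text-to-unit decoding,
#     # then the vocoder; it returns (waveform, waveform_lengths) unless
#     # return_intermediate_token_ids is set.
#     waveform, waveform_lengths = model.generate(**text_inputs, tgt_lang="spa")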