optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py (423 lines of code) (raw):

# coding=utf-8 # Copyright 2022 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import random from typing import Optional, Tuple, Union import torch from habana_frameworks.torch.hpu import get_device_name from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.fsdp import is_fsdp_managed_module from transformers.modeling_outputs import ( BaseModelOutput, CausalLMOutput, Wav2Vec2BaseModelOutput, ) from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION, Wav2Vec2Attention, logger try: from habana_frameworks.torch.hpex.kernels import CTCLoss custom_ctc_loss_fwd = CTCLoss.apply except ImportError: print("Could not import Custom CTCLoss kernel. This Kernel is available only for SynapseAI >= 1.15.0") custom_ctc_loss_fwd = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA except ImportError: print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None def _gaudi_wav2vec2_compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, attention_mask: Optional[torch.LongTensor] = None, min_masks: int = 0, ) -> torch.Tensor: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L135 The only differences are (1) that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers), (2) epsilon is generated on HPU instead of CPU, (3) check to ensure indices are not larger than sequence length is re-written to avoid host sync. """ batch_size, sequence_length = shape if mask_length < 1: raise ValueError("`mask_length` has to be bigger than 0.") if mask_length > sequence_length: raise ValueError( f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" f" and `sequence_length`: {sequence_length}`" ) # epsilon is used for probabilistic rounding epsilon = torch.rand(1).item() def compute_num_masked_span(input_length): """Given input length, compute how many spans should be masked""" num_masked_span = int(mask_prob * input_length / mask_length + epsilon) num_masked_span = max(num_masked_span, min_masks) # make sure num masked span <= sequence_length if num_masked_span * mask_length > sequence_length: num_masked_span = sequence_length // mask_length # make sure num_masked span is also <= input_length - (mask_length - 1) if input_length - (mask_length - 1) < num_masked_span: num_masked_span = max(input_length - (mask_length - 1), 0) return num_masked_span # compute number of masked spans in batch input_lengths = ( attention_mask.detach().sum(-1).tolist() if attention_mask is not None else [sequence_length for _ in range(batch_size)] ) # SpecAugment mask to fill spec_aug_mask = torch.zeros((batch_size, sequence_length), dtype=torch.bool, device="hpu") spec_aug_mask_idxs = [] max_num_masked_span = compute_num_masked_span(sequence_length) if max_num_masked_span == 0: return spec_aug_mask for input_length in input_lengths: # compute num of masked spans for this input num_masked_span = compute_num_masked_span(input_length) # get random indices to mask spec_aug_mask_idx = torch.tensor( random.sample(range(input_length - (mask_length - 1)), num_masked_span), dtype=torch.int32 ) # pick first sampled index that will serve as a dummy index to pad vector # to ensure same dimension for all batches due to probabilistic rounding # Picking first sample just pads those vectors twice. if len(spec_aug_mask_idx) == 0: # this case can only happen if `input_length` is strictly smaller then # `sequence_length` in which case the last token has to be a padding # token which we can use as a dummy mask id dummy_mask_idx = sequence_length - 1 else: dummy_mask_idx = spec_aug_mask_idx[0] spec_aug_mask_idx = torch.cat( [ spec_aug_mask_idx, torch.ones(max_num_masked_span - num_masked_span, dtype=torch.int32) * dummy_mask_idx, ] ) spec_aug_mask_idxs.append(spec_aug_mask_idx.to(dtype=torch.long)) spec_aug_mask_idxs = torch.vstack(spec_aug_mask_idxs).to("hpu") # expand masked indices to masked spans spec_aug_mask_idxs = torch.broadcast_to( spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) ) spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) # add offset to the starting indexes so that indexes now create a span offsets = torch.arange(mask_length, device="hpu")[None, None, :] offsets = torch.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( batch_size, max_num_masked_span * mask_length ) spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None: if spec_aug_mask_idxs.max() > sequence_length - 1: spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 else: mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) inverse_mask = torch.bitwise_not(mask) spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask # scatter indices to mask spec_aug_mask.scatter_(-1, spec_aug_mask_idxs, 1) return spec_aug_mask def _gaudi_wav2vec2_sample_negative_indices( features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[torch.Tensor] = None ): """ Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L254 The only difference is that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers). """ batch_size, sequence_length = features_shape # generate indices of the positive vectors themselves, repeat them `num_negatives` times sequence_length_range = torch.arange(sequence_length, device="hpu") # get `num_negatives` random vector indices from the same utterance sampled_negative_indices = torch.zeros( shape=(batch_size, sequence_length, num_negatives), dtype=torch.int32, device="hpu" ) mask_time_indices = ( mask_time_indices.bool() if mask_time_indices is not None else torch.ones(features_shape, dtype=torch.bool, device="hpu") ) for batch_idx in range(batch_size): high = mask_time_indices[batch_idx].sum() - 1 mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] feature_indices = torch.broadcast_to(torch.arange(high + 1, device="hpu")[:, None], (high + 1, num_negatives)) sampled_indices = torch.randint(0, high, size=(high + 1, num_negatives), dtype=torch.int16, device="hpu") # avoid sampling the same positive vector, but keep the distribution uniform sampled_indices[sampled_indices >= feature_indices] += 1 # remap to actual indices sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] # correct for batch size sampled_negative_indices[batch_idx] += batch_idx * sequence_length return sampled_negative_indices def gaudi_wav2vec2_encoder_forward( self, hidden_states: torch.tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): """ Copied from Transformers: https://github.com/huggingface/transformers/blob/7790943c91411f4234d11dfbf4c2f21ce7caf088/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L755 The only difference is that torch.rand device is set to 'hpu' (required to capture operation as part of HPU graph) """ all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if attention_mask is not None: # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 # extend attention_mask attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min attention_mask = attention_mask.expand( attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, output_attentions, ) else: layer_outputs = layer( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions ) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) def gaudi_wav2vec2_forward( self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, mask_time_indices: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1282 The only difference is that a clone of `hidden_states` is given to _mask_hidden_states to avoid an error. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict extract_features = self.feature_extractor(input_values) extract_features = extract_features.transpose(1, 2) if attention_mask is not None: # compute reduced attention_mask corresponding to feature vectors attention_mask = self._get_feature_vector_attention_mask( extract_features.shape[1], attention_mask, add_adapter=False ) hidden_states, extract_features = self.feature_projection(extract_features) hidden_states = self._mask_hidden_states( hidden_states.clone(), mask_time_indices=mask_time_indices, attention_mask=attention_mask ) encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = encoder_outputs[0] if self.adapter is not None: hidden_states = self.adapter(hidden_states) if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] return Wav2Vec2BaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def _gaudi_wav2vec2_mask_hidden_states( self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): """ Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1227 Differences are that (1) `mask_time_indices` is not moved to the current device and converted into boolean because this is already done in _compute_mask_indices. (2) index_put operation on hidden_states is replaced by combination of simpler ops (more suitable for HPU graphs) """ # `config.apply_spec_augment` can set masking to False if not getattr(self.config, "apply_spec_augment", True): return hidden_states # generate indices & apply SpecAugment along time axis batch_size, sequence_length, hidden_size = hidden_states.size() if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _gaudi_wav2vec2_compute_mask_indices( (batch_size, sequence_length), mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, attention_mask=attention_mask, min_masks=self.config.mask_time_min_masks, ) # replacement of index_put with combination of simpler ops. Assumption made about sizes of hidden_states (3d), # mask_time_indices (2d), self.masked_spec_embed (1d), for any other combination better to go back to original code using index_put. # hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) inverse_mask_time_indices = torch.bitwise_not(mask_time_indices) hidden_states = hidden_states * inverse_mask_time_indices.unsqueeze(2) + self.masked_spec_embed.to( hidden_states.dtype ).expand(hidden_states.size()) * mask_time_indices.unsqueeze(2) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis mask_feature_indices = _gaudi_wav2vec2_compute_mask_indices( (batch_size, hidden_size), mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, min_masks=self.config.mask_feature_min_masks, ) mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) hidden_states[mask_feature_indices] = 0 return hidden_states def gaudi_wav2vec2forctc_forward( self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutput]: """ copied from Transformers https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1950 only differences are (1) attention_mask tensor generation using ones_like is done on HPU, (2) masked_select is not applied on labels to compute flattened_targets to avoid changing flattened_targets tensor shapes across training iterations. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None and labels.max() >= self.config.vocab_size: raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") outputs = self.wav2vec2( input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = self.dropout(hidden_states) logits = self.lm_head(hidden_states) loss = None if labels is not None: # retrieve loss input_lengths from attention_mask attention_mask = ( attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long, device="hpu") ) input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) # assuming that padded tokens are filled with -100 # when not being attended to labels_mask = labels >= 0 target_lengths = labels_mask.sum(-1) # ctc_loss doesn't support fp16 log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) if get_device_name() == "GAUDI" or custom_ctc_loss_fwd is None: flattened_targets = labels.masked_select(labels_mask) loss = torch.nn.functional.ctc_loss( log_probs, flattened_targets, input_lengths, target_lengths, blank=self.config.pad_token_id, reduction=self.config.ctc_loss_reduction, zero_infinity=self.config.ctc_zero_infinity, ) else: flattened_targets = labels loss = custom_ctc_loss_fwd( log_probs, flattened_targets, input_lengths, target_lengths, self.config.pad_token_id, self.config.ctc_loss_reduction, self.config.ctc_zero_infinity, ) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] return ((loss,) + output) if loss is not None else output return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L2290 v4.38.2 implementation caused accuracy issue to run pytest Wav2Vec2RobustModelTest. """ hidden_states = hidden_states.unsqueeze(1) hidden_states = torch.nn.functional.unfold( hidden_states, (self.kernel_size, self.in_conv_dim), stride=(1, self.in_conv_dim), dilation=(self.dilation, 1), ) hidden_states = hidden_states.transpose(1, 2) hidden_states = self.kernel(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states class GaudiWav2Vec2SdpaAttention(Wav2Vec2Attention): # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2 def __init__( self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, is_causal: bool = False, config: Optional[Wav2Vec2Config] = None, ): super().__init__( embed_dim, num_heads, dropout, is_decoder, bias, is_causal, config, ) self.use_flash_attention = True if os.getenv("USE_FLASH_ATTENTION") == "1" else False self.flash_attention_fast_softmax = True if os.getenv("FLASH_ATTENTION_FAST_SOFTMAX") == "1" else False self.flash_attention_recompute = True if os.getenv("FLASH_ATTENTION_RECOMPUTE") == "1" else False def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from Wav2Vec2SdpaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py The only difference is If the `USE_FLASH_ATTENTION` switch is enabled, then use the HPU's fused SDPA; otherwise, use PyTorch's native SDPA """ """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( "Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states, key_value_states=key_value_states, past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning if ( is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) query_states = self._shape(query_states, tgt_len, bsz) # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False if self.use_flash_attention and FusedSDPA: if tgt_len == 1: # next token softmax_mode = True if os.getenv("QUANT_CONFIG", "") else False recompute_mode = False else: # first token softmax_mode = "fast" if self.flash_attention_fast_softmax else "None" recompute_mode = self.flash_attention_recompute attn_output = FusedSDPA.apply( query_states, key_states, value_states, attention_mask, self.dropout if self.training else 0.0, is_causal, None, # scale ; Non -> default scale softmax_mode, # "None" -> Non fast softmax recompute_mode, # -> recompute_mode = false ) else: # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value