optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py

# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import numpy as np
import poptorch
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForPreTraining, Wav2Vec2Model
from transformers.modeling_outputs import CausalLMOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Adapter,
    Wav2Vec2Encoder,
    Wav2Vec2EncoderStableLayerNorm,
    Wav2Vec2ForCTC,
    Wav2Vec2ForPreTrainingOutput,
    Wav2Vec2GumbelVectorQuantizer,
)

from optimum.utils import logging

from ...modeling_utils import PipelineMixin, get_layer_ipu, recomputation_checkpoint, register
from .ipu_gumbel_vector_quantizer import IPUWav2Vec2GumbelVectorQuantizer
from .ipu_layer_drop import IPUWav2Vec2Adapter, IPUWav2Vec2Encoder, IPUWav2Vec2EncoderStableLayerNorm


logger = logging.get_logger(__name__)


class IPUWav2Vec2Model(Wav2Vec2Model):
    def _get_feature_vector_attention_mask(
        self,
        feature_vector_length: int,
        attention_mask: torch.LongTensor,
        add_adapter=None,
    ):
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        # non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
        # non_padded_lengths = attention_mask.cumsum(dim=-1)[:, 249999]
        non_padded_lengths = attention_mask.sum(dim=-1)

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        # these two operations make sure that all values before the output length indices are attended to
        attention_mask[
            (
                torch.arange(attention_mask.shape[0], device=attention_mask.device),
                output_lengths - 1,
            )
        ] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask


@register(Wav2Vec2ForPreTraining)
class PipelinedWav2Vec2ForPreTraining(Wav2Vec2ForPreTraining, PipelineMixin):
    def change_wav2vec2_encoder_class(self, restore: bool):
        """Changes the encoder class to update its forward pass so that it uses our custom version.

        Args:
            restore: whether to restore the encoder to its original version or not.
        """
        if self.config.do_stable_layer_norm:
            new_cls = Wav2Vec2EncoderStableLayerNorm if restore else IPUWav2Vec2EncoderStableLayerNorm
        else:
            new_cls = Wav2Vec2Encoder if restore else IPUWav2Vec2Encoder
        self.wav2vec2.encoder.__class__ = new_cls

    def change_wav2vec2_adapter_class(self, restore: bool):
        """Changes the adapter class to update its forward pass so that it uses our custom version.

        Args:
            restore: whether to restore the adapter to its original version or not.
""" if self.config.add_adapter: self.wav2vec2.adapter.__class__ = Wav2Vec2Adapter if restore else IPUWav2Vec2Adapter def change_quantizer_class(self, restore: bool): """Changes the quantizer class to update its forward pass so that it uses our custom version. Args: restore: whether to restore the quantizer to its original version or not. """ self.quantizer.__class__ = Wav2Vec2GumbelVectorQuantizer if restore else IPUWav2Vec2GumbelVectorQuantizer def change_conv_eps(self, restore: bool): """Changes the epsilons in the layer norms of the conv layers to a value suitable for float16. Args: restore: whether to restore the epsilons to their original version or not. """ if self.config.feat_extract_norm != "layer": # In this case there is no layer norm in the conv layers return if restore: for i, conv_layer in enumerate(self.wav2vec2.feature_extractor.conv_layers): # Restore the original values conv_layer.layer_norm.eps = self.original_eps[i] else: self.original_eps = [] eps = 1e-4 for conv_layer in self.wav2vec2.feature_extractor.conv_layers: # Save the original values, to restore later self.original_eps.append(conv_layer.layer_norm.eps) conv_layer.layer_norm.eps = eps def _add_begin_block(self, module, name, ipu_id): poptorch.BeginBlock(module, name, ipu_id) def parallelize(self): """ Transform the model to run in an IPU pipeline. - Adds pipeline stages to the model - Replaces some layers with IPU-specialised ones - Set eps to a stable value in float16 Recommended usage: ``` model = PipelinedWav2Vec2ForPreTraining(config).parallelize().half() ``` """ super().parallelize() self.wav2vec2.__class__ = IPUWav2Vec2Model self.change_wav2vec2_encoder_class(False) self.change_wav2vec2_adapter_class(False) self.change_quantizer_class(False) self.change_conv_eps(False) logger.info("---------- Device Allocation -----------") layers = [] # Conv layers for index, layer in enumerate(self.wav2vec2.feature_extractor.conv_layers): layers.append((f"Conv {index:<2}", layer)) # Positional Embedding layers.append(("Positional Embedding", self.wav2vec2.encoder.pos_conv_embed)) # Encoder layers for index, layer in enumerate(self.wav2vec2.encoder.layers): self._hooks.append(recomputation_checkpoint(layer)) layers.append((f"Encoder {index:<2}", layer)) # Project Hidden layers.append(("Project Hidden", self.project_hid)) # Quantizer layers.append(("Quantizer", self.quantizer)) # Project Quantizer layers.append(("Project Quantizer", self.project_q)) layer_ipu = get_layer_ipu(self.ipu_config, layers) for i, (name, layer) in enumerate(layers): logger.info(f"{name} --> IPU {layer_ipu[i]}") self._add_begin_block(layer, name, ipu_id=layer_ipu[i]) logger.info("---------------------------------------") def deparallelize(self): """ Undo the changes to the model done by `parallelize`. You should call this before doing `save_pretrained` so that the `model.state_dict` is fully compatible with `transformers.Wav2Vec2ForPreTraining`. 
""" super().deparallelize() self.change_wav2vec2_encoder_class(True) self.change_wav2vec2_adapter_class(True) self.change_quantizer_class(True) self.change_conv_eps(True) self.wav2vec2.__class__ = Wav2Vec2Model return self def forward( self, input_values: Optional[torch.Tensor], gumbel_temperature: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, mask_time_indices: Optional[torch.BoolTensor] = None, sampled_negative_indices: Optional[torch.BoolTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, mask_reduced: Optional[torch.Tensor] = None, reduce_selector: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, Wav2Vec2ForPreTrainingOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict if mask_time_indices is not None: mask_time_indices = mask_time_indices.to(torch.bool) if gumbel_temperature is None: gumbel_temperature = torch.tensor( self.quantizer.temperature, device=input_values.device, dtype=input_values.dtype ) outputs = self.wav2vec2( input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, mask_time_indices=mask_time_indices, return_dict=return_dict, ) transformer_features, extract_features = outputs[0], outputs[1] if attention_mask is not None: # compute reduced attention_mask correponding to feature vectors attention_mask = self.wav2vec2._get_feature_vector_attention_mask( extract_features.shape[1], attention_mask, add_adapter=False ) # GC. remove a (static sized) portion of the output tensors at unmasked indices # unmasked indices do not contribute to loss. removing them now alleviates memory requirements if reduce_selector is not None: batch_size, sequence_length, feature_size = extract_features.shape cropped_length = reduce_selector.shape[1] if batch_size > 1: reduce_selector += torch.arange(batch_size, device=input_values.device).unsqueeze(1) * sequence_length mask_time_indices = mask_reduced.to(torch.bool) extract_features = extract_features.view(-1, feature_size)[reduce_selector.long().view(-1)] extract_features = extract_features.reshape(batch_size, cropped_length, feature_size) _, _, feature_size = transformer_features.shape transformer_features = transformer_features.view(-1, feature_size)[reduce_selector.long().view(-1)] transformer_features = transformer_features.reshape(batch_size, cropped_length, feature_size) # 1. project all transformed features (including masked) to final vq dim transformer_features = self.project_hid(transformer_features) # 2. quantize all (unmasked) extracted features and project to final vq dim extract_features = self.dropout_features(extract_features) if isinstance(self.quantizer, IPUWav2Vec2GumbelVectorQuantizer): quantized_features, code_perplexity, prob_perplexity = self.quantizer( extract_features, gumbel_temperature.mean(), mask_time_indices=mask_time_indices, ) else: quantized_features, code_perplexity = self.quantizer( extract_features, mask_time_indices=mask_time_indices, ) prob_perplexity = None quantized_features = self.project_q(quantized_features) loss = contrastive_loss = diversity_loss = None if sampled_negative_indices is not None: batch_size, sequence_length, hidden_size = quantized_features.shape # for training, we sample negatives # 3. 
            # 3. sample K negatives (distractors) quantized states for contrastive loss
            # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
            # sample negative quantized vectors BTC => (BxT)C
            # Moved the negative sampling batch offsetting into the model
            if batch_size > 1:
                sampled_negative_indices += (
                    torch.arange(batch_size, device=input_values.device)[:, None, None] * sequence_length
                )

            negative_quantized_features = quantized_features.view(-1, hidden_size)[
                sampled_negative_indices.long().view(-1)
            ]
            negative_quantized_features = negative_quantized_features.view(
                batch_size, sequence_length, -1, hidden_size
            ).permute(2, 0, 1, 3)

            # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
            # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
            logits = self.compute_contrastive_logits(
                quantized_features[None, :],
                negative_quantized_features,
                transformer_features,
                self.config.contrastive_logits_temperature,
            )

            # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
            # its cosine similarity will be masked
            neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
            neg_is_pos = F.pad(neg_is_pos, (0, 0, 0, 0, 1, 0))
            logits = logits.masked_fill(neg_is_pos, -1e3)

            # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
            # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
            logits = logits.permute(1, 2, 0).reshape(batch_size * sequence_length, -1)
            target = ((1 - mask_time_indices.long()) * -100).flatten()

            contrastive_loss = F.cross_entropy(logits.float(), target, reduction="sum")

            # 7. compute diversity loss: \mathbf{L}_d
            num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
            diversity_loss = ((num_codevectors - prob_perplexity) / num_codevectors) * mask_time_indices.sum()

            # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
            loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss

        if not return_dict:
            if loss is not None:
                return (
                    loss,
                    transformer_features,
                    quantized_features,
                    code_perplexity,
                    prob_perplexity,
                ) + outputs[2:]
            return (
                transformer_features,
                quantized_features,
                code_perplexity,
                prob_perplexity,
            ) + outputs[2:]

        return Wav2Vec2ForPreTrainingOutput(
            loss=loss,
            projected_states=transformer_features,
            projected_quantized_states=quantized_features,
            codevector_perplexity=code_perplexity,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            contrastive_loss=contrastive_loss,
            diversity_loss=diversity_loss,
        )

    @staticmethod
    def compute_contrastive_logits(
        target_features: torch.FloatTensor,
        negative_features: torch.FloatTensor,
        predicted_features: torch.FloatTensor,
        temperature: float = 0.1,
    ):
        """
        Compute logits for contrastive loss using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
        """
        target_features = torch.cat([target_features, negative_features], dim=0)

        logits = torch.cosine_similarity(
            predicted_features.float(), target_features.float(), dim=-1, eps=1e-4
        ).type_as(target_features)

        # apply temperature
        logits = logits / temperature
        return logits


@register(Wav2Vec2ForCTC)
class PipelinedWav2Vec2ForCTC(Wav2Vec2ForCTC, PipelineMixin):
    def change_wav2vec2_encoder_class(self, restore: bool):
        """Changes the encoder class to update its forward pass so that it uses our custom version.

        Args:
            restore: whether to restore the encoder to its original version or not.
""" if self.config.do_stable_layer_norm: new_cls = Wav2Vec2EncoderStableLayerNorm if restore else IPUWav2Vec2EncoderStableLayerNorm else: new_cls = Wav2Vec2Encoder if restore else IPUWav2Vec2Encoder self.wav2vec2.encoder.__class__ = new_cls def change_wav2vec2_adapter_class(self, restore: bool): """Changes the adapter class to update its forward pass so that it uses our custom version. Args: restore: whether to restore the adapter to its original version or not. """ if self.config.add_adapter: self.wav2vec2.adapter.__class__ = Wav2Vec2Adapter if restore else IPUWav2Vec2Adapter def change_conv_eps(self, restore: bool): """Changes the epsilons in the layer norms of the conv layers to a value suitable for float16. Args: restore: whether to restore the epsilons to their original version or not. """ if self.config.feat_extract_norm != "layer": # In this case there is no layer norm in the conv layers return if restore: for i, conv_layer in enumerate(self.wav2vec2.feature_extractor.conv_layers): # Restore the original values conv_layer.layer_norm.eps = self.original_eps[i] else: self.original_eps = [] for conv_layer in self.wav2vec2.feature_extractor.conv_layers: eps = 1e-4 if conv_layer.layer_norm.weight.dtype == torch.float16 else conv_layer.layer_norm.eps # Save the original values, to restore later self.original_eps.append(conv_layer.layer_norm.eps) conv_layer.layer_norm.eps = eps def _add_begin_block(self, module, name, ipu_id): poptorch.BeginBlock(module, name, ipu_id) def parallelize(self): """ Transform the model to run in an IPU pipeline. - Adds pipeline stages to the model - Replaces some layers with IPU-specialised ones - Set eps to a stable value in float16 Recommended usage: ``` model = PipelinedWav2Vec2ForPreTraining(config).parallelize().half() ``` """ super().parallelize() self.wav2vec2.__class__ = IPUWav2Vec2Model self.freeze_feature_encoder() self.change_wav2vec2_encoder_class(False) self.change_wav2vec2_adapter_class(False) self.change_conv_eps(False) if self.ipu_config._ipus_per_replica != 1: logger.info("---------- Device Allocation -----------") layers = [] # Conv layers for index, layer in enumerate(self.wav2vec2.feature_extractor.conv_layers): layers.append((f"Conv {index:<2}", layer)) # Positional Embedding layers.append(("Positional Embedding", self.wav2vec2.encoder.pos_conv_embed)) # Encoder layers for index, layer in enumerate(self.wav2vec2.encoder.layers): self._hooks.append(recomputation_checkpoint(layer)) layers.append((f"Encoder {index:<2}", layer)) # Project Hidden layers.append(("Project Hidden", self.lm_head)) layer_ipu = get_layer_ipu(self.ipu_config, layers) for i, (name, layer) in enumerate(layers): logger.info(f"{name} --> IPU {layer_ipu[i]}") self._add_begin_block(layer, name, ipu_id=layer_ipu[i]) logger.info("---------------------------------------") def deparallelize(self): """ Undo the changes to the model done by `parallelize`. You should call this before doing `save_pretrained` so that the `model.state_dict` is fully compatible with `transformers.Wav2Vec2ForPreTraining`. 
""" super().deparallelize() self.change_wav2vec2_encoder_class(True) self.change_wav2vec2_adapter_class(True) self.change_conv_eps(True) self.wav2vec2.__class__ = Wav2Vec2Model return self def forward( self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.wav2vec2( input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = self.dropout(hidden_states) logits = self.lm_head(hidden_states) loss = None if labels is not None: # retrieve loss input_lengths from attention_mask attention_mask = ( attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) ) input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) # assuming that padded tokens are filled with -100 # when not being attended to labels_mask = labels >= 0 target_lengths = labels_mask.sum(-1) # ctc_loss doesn't support fp16 log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1).transpose(0, 1) loss_fn = torch.nn.CTCLoss( blank=self.config.pad_token_id, reduction=self.config.ctc_loss_reduction, zero_infinity=self.config.ctc_zero_infinity, ) loss = loss_fn(log_probs, labels, input_lengths, target_lengths) loss = poptorch.identity_loss(loss, "none") if not return_dict: if loss is not None: return loss, logits return (logits, hidden_states) return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states) def _sample_negative_indices( features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None, ): """ Sample `num_negatives` vectors from feature vectors. 
""" batch_size, sequence_length = features_shape # generate indices of the positive vectors themselves, repeat them `num_negatives` times sequence_length_range = np.arange(sequence_length) # get `num_negatives` random vector indices from the same utterance sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) mask_time_indices = ( mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool) ) for batch_idx in range(batch_size): high = mask_time_indices[batch_idx].sum() - 1 mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives)) sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives)) # avoid sampling the same positive vector, but keep the distribution uniform sampled_indices[sampled_indices >= feature_indices] += 1 # remap to actual indices sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] # Moved the offsetting into the model to stop issues with gradient accumulation # sampled_negative_indices[batch_idx] += batch_idx * sequence_length return sampled_negative_indices