# optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import poptorch
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForPreTraining, Wav2Vec2Model
from transformers.modeling_outputs import CausalLMOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Adapter,
Wav2Vec2Encoder,
Wav2Vec2EncoderStableLayerNorm,
Wav2Vec2ForCTC,
Wav2Vec2ForPreTrainingOutput,
Wav2Vec2GumbelVectorQuantizer,
)
from optimum.utils import logging
from ...modeling_utils import PipelineMixin, get_layer_ipu, recomputation_checkpoint, register
from .ipu_gumbel_vector_quantizer import IPUWav2Vec2GumbelVectorQuantizer
from .ipu_layer_drop import IPUWav2Vec2Adapter, IPUWav2Vec2Encoder, IPUWav2Vec2EncoderStableLayerNorm
logger = logging.get_logger(__name__)
class IPUWav2Vec2Model(Wav2Vec2Model):
def _get_feature_vector_attention_mask(
self,
feature_vector_length: int,
attention_mask: torch.LongTensor,
add_adapter=None,
):
        # Number of non-padded samples per batch element.
        non_padded_lengths = attention_mask.sum(dim=-1)
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = output_lengths.to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
        # these two operations make sure that all values before the output length indices are attended to
attention_mask[
(
torch.arange(attention_mask.shape[0], device=attention_mask.device),
output_lengths - 1,
)
] = 1
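        # Illustration (hypothetical values): with output_lengths = [2] and feature_vector_length = 5,
        # the scatter above gives [0, 1, 0, 0, 0]; flip -> cumsum -> flip below then yields
        # [1, 1, 0, 0, 0], i.e. every position before the output length is marked as attended.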
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
@register(Wav2Vec2ForPreTraining)
class PipelinedWav2Vec2ForPreTraining(Wav2Vec2ForPreTraining, PipelineMixin):
def change_wav2vec2_encoder_class(self, restore: bool):
"""Changes the encoder class to update its forward pass so that it uses our custom version.
Args:
restore: whether to restore the encoder to its original version or not.
"""
if self.config.do_stable_layer_norm:
new_cls = Wav2Vec2EncoderStableLayerNorm if restore else IPUWav2Vec2EncoderStableLayerNorm
else:
new_cls = Wav2Vec2Encoder if restore else IPUWav2Vec2Encoder
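        # Reassigning __class__ swaps only the methods (e.g. forward); parameters and buffers are
        # untouched, so the state_dict stays compatible with the original transformers classes.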
self.wav2vec2.encoder.__class__ = new_cls
def change_wav2vec2_adapter_class(self, restore: bool):
"""Changes the adapter class to update its forward pass so that it uses our custom version.
Args:
restore: whether to restore the adapter to its original version or not.
"""
if self.config.add_adapter:
self.wav2vec2.adapter.__class__ = Wav2Vec2Adapter if restore else IPUWav2Vec2Adapter
def change_quantizer_class(self, restore: bool):
"""Changes the quantizer class to update its forward pass so that it uses our custom version.
Args:
restore: whether to restore the quantizer to its original version or not.
"""
self.quantizer.__class__ = Wav2Vec2GumbelVectorQuantizer if restore else IPUWav2Vec2GumbelVectorQuantizer
def change_conv_eps(self, restore: bool):
"""Changes the epsilons in the layer norms of the conv layers to a value suitable for float16.
Args:
restore: whether to restore the epsilons to their original version or not.
"""
if self.config.feat_extract_norm != "layer":
# In this case there is no layer norm in the conv layers
return
if restore:
for i, conv_layer in enumerate(self.wav2vec2.feature_extractor.conv_layers):
# Restore the original values
conv_layer.layer_norm.eps = self.original_eps[i]
else:
self.original_eps = []
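            # The original eps (1e-5 by default) is below the float16 normal range (~6.1e-5), so it may
            # lose precision or be flushed to zero in half precision; 1e-4 is safely representable.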
eps = 1e-4
for conv_layer in self.wav2vec2.feature_extractor.conv_layers:
# Save the original values, to restore later
self.original_eps.append(conv_layer.layer_norm.eps)
conv_layer.layer_norm.eps = eps
def _add_begin_block(self, module, name, ipu_id):
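        # poptorch.BeginBlock marks the start of a pipeline stage: this module and everything executed
        # after it (until the next BeginBlock) is placed on the given IPU.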
poptorch.BeginBlock(module, name, ipu_id)
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- Replaces some layers with IPU-specialised ones
        - Sets eps to a value that is stable in float16
Recommended usage:
```
model = PipelinedWav2Vec2ForPreTraining(config).parallelize().half()
```
"""
super().parallelize()
self.wav2vec2.__class__ = IPUWav2Vec2Model
self.change_wav2vec2_encoder_class(False)
self.change_wav2vec2_adapter_class(False)
self.change_quantizer_class(False)
self.change_conv_eps(False)
logger.info("---------- Device Allocation -----------")
layers = []
# Conv layers
for index, layer in enumerate(self.wav2vec2.feature_extractor.conv_layers):
layers.append((f"Conv {index:<2}", layer))
# Positional Embedding
layers.append(("Positional Embedding", self.wav2vec2.encoder.pos_conv_embed))
# Encoder layers
for index, layer in enumerate(self.wav2vec2.encoder.layers):
self._hooks.append(recomputation_checkpoint(layer))
layers.append((f"Encoder {index:<2}", layer))
# Project Hidden
layers.append(("Project Hidden", self.project_hid))
# Quantizer
layers.append(("Quantizer", self.quantizer))
# Project Quantizer
layers.append(("Project Quantizer", self.project_q))
layer_ipu = get_layer_ipu(self.ipu_config, layers)
for i, (name, layer) in enumerate(layers):
logger.info(f"{name} --> IPU {layer_ipu[i]}")
self._add_begin_block(layer, name, ipu_id=layer_ipu[i])
logger.info("---------------------------------------")
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
fully compatible with `transformers.Wav2Vec2ForPreTraining`.
"""
super().deparallelize()
self.change_wav2vec2_encoder_class(True)
self.change_wav2vec2_adapter_class(True)
self.change_quantizer_class(True)
self.change_conv_eps(True)
self.wav2vec2.__class__ = Wav2Vec2Model
return self
def forward(
self,
input_values: Optional[torch.Tensor],
gumbel_temperature: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.BoolTensor] = None,
sampled_negative_indices: Optional[torch.BoolTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
mask_reduced: Optional[torch.Tensor] = None,
reduce_selector: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Wav2Vec2ForPreTrainingOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if mask_time_indices is not None:
mask_time_indices = mask_time_indices.to(torch.bool)
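        # The Gumbel temperature is taken as a tensor input rather than a Python scalar; presumably
        # this allows it to be annealed between steps without recompiling the static IPU graph.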
if gumbel_temperature is None:
gumbel_temperature = torch.tensor(
self.quantizer.temperature, device=input_values.device, dtype=input_values.dtype
)
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mask_time_indices=mask_time_indices,
return_dict=return_dict,
)
transformer_features, extract_features = outputs[0], outputs[1]
if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
attention_mask = self.wav2vec2._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
        # Graphcore-specific: remove a statically sized portion of the output tensors at unmasked indices.
        # Unmasked indices do not contribute to the loss, so dropping them here reduces memory requirements.
if reduce_selector is not None:
batch_size, sequence_length, feature_size = extract_features.shape
cropped_length = reduce_selector.shape[1]
if batch_size > 1:
reduce_selector += torch.arange(batch_size, device=input_values.device).unsqueeze(1) * sequence_length
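            # The offset turns per-sample indices into flat (batch * sequence) indices, so a single
            # gather over the flattened feature tensors can be used below.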
mask_time_indices = mask_reduced.to(torch.bool)
extract_features = extract_features.view(-1, feature_size)[reduce_selector.long().view(-1)]
extract_features = extract_features.reshape(batch_size, cropped_length, feature_size)
_, _, feature_size = transformer_features.shape
transformer_features = transformer_features.view(-1, feature_size)[reduce_selector.long().view(-1)]
transformer_features = transformer_features.reshape(batch_size, cropped_length, feature_size)
# 1. project all transformed features (including masked) to final vq dim
transformer_features = self.project_hid(transformer_features)
# 2. quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(extract_features)
if isinstance(self.quantizer, IPUWav2Vec2GumbelVectorQuantizer):
quantized_features, code_perplexity, prob_perplexity = self.quantizer(
extract_features,
gumbel_temperature.mean(),
mask_time_indices=mask_time_indices,
)
else:
quantized_features, code_perplexity = self.quantizer(
extract_features,
mask_time_indices=mask_time_indices,
)
prob_perplexity = None
quantized_features = self.project_q(quantized_features)
loss = contrastive_loss = diversity_loss = None
if sampled_negative_indices is not None:
batch_size, sequence_length, hidden_size = quantized_features.shape
# for training, we sample negatives
# 3. sample K negatives (distractors) quantized states for contrastive loss
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
# sample negative quantized vectors BTC => (BxT)C
# Moved the negative sampling batch offsetting into the model
if batch_size > 1:
sampled_negative_indices += (
torch.arange(batch_size, device=input_values.device)[:, None, None] * sequence_length
)
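            # As with reduce_selector above, the offset flattens per-sample indices so the negatives
            # can be gathered from the flattened quantized_features below.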
negative_quantized_features = quantized_features.view(-1, hidden_size)[
sampled_negative_indices.long().view(-1)
]
negative_quantized_features = negative_quantized_features.view(
batch_size, sequence_length, -1, hidden_size
).permute(2, 0, 1, 3)
            # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
logits = self.compute_contrastive_logits(
quantized_features[None, :],
negative_quantized_features,
transformer_features,
self.config.contrastive_logits_temperature,
)
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
# its cosine similarity will be masked
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
neg_is_pos = F.pad(neg_is_pos, (0, 0, 0, 0, 1, 0))
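            # The pad prepends a row of False for the positive logit (index 0), so only negative rows can
            # be masked; a large negative constant is used instead of -inf, presumably for float16 stability.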
logits = logits.masked_fill(neg_is_pos, -1e3)
            # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
logits = logits.permute(1, 2, 0).reshape(batch_size * sequence_length, -1)
target = ((1 - mask_time_indices.long()) * -100).flatten()
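            # Masked positions get target 0 (the positive row of the logits); unmasked positions get -100,
            # which cross_entropy ignores by default.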
contrastive_loss = F.cross_entropy(logits.float(), target, reduction="sum")
# 7. compute diversity loss: \mathbf{L}_d
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
diversity_loss = ((num_codevectors - prob_perplexity) / num_codevectors) * mask_time_indices.sum()
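            # prob_perplexity approaches num_codevectors when codebook usage is uniform, so this term goes
            # to zero as diversity improves; scaling by mask_time_indices.sum() matches the summed contrastive loss.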
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
if not return_dict:
if loss is not None:
return (
loss,
transformer_features,
quantized_features,
code_perplexity,
prob_perplexity,
) + outputs[2:]
return (
transformer_features,
quantized_features,
code_perplexity,
prob_perplexity,
) + outputs[2:]
return Wav2Vec2ForPreTrainingOutput(
loss=loss,
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=code_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
contrastive_loss=contrastive_loss,
diversity_loss=diversity_loss,
)
@staticmethod
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
predicted_features: torch.FloatTensor,
        temperature: float = 0.1,
):
"""
        Compute logits for the contrastive loss, using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, a temperature can be applied.
"""
target_features = torch.cat([target_features, negative_features], dim=0)
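        # Shapes as used by the caller: target_features (1, B, T, H), negative_features (K, B, T, H);
        # after concatenation, predicted_features (B, T, H) broadcasts to give logits of shape (1 + K, B, T).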
logits = torch.cosine_similarity(
predicted_features.float(), target_features.float(), dim=-1, eps=1e-4
).type_as(target_features)
# apply temperature
logits = logits / temperature
return logits
@register(Wav2Vec2ForCTC)
class PipelinedWav2Vec2ForCTC(Wav2Vec2ForCTC, PipelineMixin):
def change_wav2vec2_encoder_class(self, restore: bool):
"""Changes the encoder class to update its forward pass so that it uses our custom version.
Args:
restore: whether to restore the encoder to its original version or not.
"""
if self.config.do_stable_layer_norm:
new_cls = Wav2Vec2EncoderStableLayerNorm if restore else IPUWav2Vec2EncoderStableLayerNorm
else:
new_cls = Wav2Vec2Encoder if restore else IPUWav2Vec2Encoder
self.wav2vec2.encoder.__class__ = new_cls
def change_wav2vec2_adapter_class(self, restore: bool):
"""Changes the adapter class to update its forward pass so that it uses our custom version.
Args:
restore: whether to restore the adapter to its original version or not.
"""
if self.config.add_adapter:
self.wav2vec2.adapter.__class__ = Wav2Vec2Adapter if restore else IPUWav2Vec2Adapter
def change_conv_eps(self, restore: bool):
"""Changes the epsilons in the layer norms of the conv layers to a value suitable for float16.
Args:
restore: whether to restore the epsilons to their original version or not.
"""
if self.config.feat_extract_norm != "layer":
# In this case there is no layer norm in the conv layers
return
if restore:
for i, conv_layer in enumerate(self.wav2vec2.feature_extractor.conv_layers):
# Restore the original values
conv_layer.layer_norm.eps = self.original_eps[i]
else:
self.original_eps = []
for conv_layer in self.wav2vec2.feature_extractor.conv_layers:
eps = 1e-4 if conv_layer.layer_norm.weight.dtype == torch.float16 else conv_layer.layer_norm.eps
# Save the original values, to restore later
self.original_eps.append(conv_layer.layer_norm.eps)
conv_layer.layer_norm.eps = eps
def _add_begin_block(self, module, name, ipu_id):
poptorch.BeginBlock(module, name, ipu_id)
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- Replaces some layers with IPU-specialised ones
        - Sets eps to a value that is stable in float16
Recommended usage:
```
        model = PipelinedWav2Vec2ForCTC(config).parallelize().half()
```
"""
super().parallelize()
self.wav2vec2.__class__ = IPUWav2Vec2Model
self.freeze_feature_encoder()
self.change_wav2vec2_encoder_class(False)
self.change_wav2vec2_adapter_class(False)
self.change_conv_eps(False)
if self.ipu_config._ipus_per_replica != 1:
logger.info("---------- Device Allocation -----------")
layers = []
# Conv layers
for index, layer in enumerate(self.wav2vec2.feature_extractor.conv_layers):
layers.append((f"Conv {index:<2}", layer))
# Positional Embedding
layers.append(("Positional Embedding", self.wav2vec2.encoder.pos_conv_embed))
# Encoder layers
for index, layer in enumerate(self.wav2vec2.encoder.layers):
self._hooks.append(recomputation_checkpoint(layer))
layers.append((f"Encoder {index:<2}", layer))
            # LM Head
            layers.append(("LM Head", self.lm_head))
layer_ipu = get_layer_ipu(self.ipu_config, layers)
for i, (name, layer) in enumerate(layers):
logger.info(f"{name} --> IPU {layer_ipu[i]}")
self._add_begin_block(layer, name, ipu_id=layer_ipu[i])
logger.info("---------------------------------------")
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
        fully compatible with `transformers.Wav2Vec2ForCTC`.
"""
super().deparallelize()
self.change_wav2vec2_encoder_class(True)
self.change_wav2vec2_adapter_class(True)
self.change_conv_eps(True)
self.wav2vec2.__class__ = Wav2Vec2Model
return self
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
config.vocab_size - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# assuming that padded tokens are filled with -100
# when not being attended to
labels_mask = labels >= 0
target_lengths = labels_mask.sum(-1)
# ctc_loss doesn't support fp16
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1).transpose(0, 1)
loss_fn = torch.nn.CTCLoss(
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
loss = loss_fn(log_probs, labels, input_lengths, target_lengths)
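            # poptorch.identity_loss marks this tensor as the final loss for the PopTorch backward pass;
            # reduction="none" leaves the CTC reduction configured above unchanged.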
loss = poptorch.identity_loss(loss, "none")
if not return_dict:
if loss is not None:
return loss, logits
return (logits, hidden_states)
return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def _sample_negative_indices(
features_shape: Tuple,
num_negatives: int,
mask_time_indices: Optional[np.ndarray] = None,
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length = features_shape
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
sequence_length_range = np.arange(sequence_length)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
    mask_time_indices = (
        mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
    )
for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices[sampled_indices >= feature_indices] += 1
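        # Example (hypothetical values): with high = 3, feature index 2 draws from {0, 1, 2};
        # a draw of 2 is shifted to 3, so the positive index 2 is never sampled and the rest stay uniform.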
# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
# Moved the offsetting into the model to stop issues with gradient accumulation
# sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices