# optimum/graphcore/models/hubert/modeling_hubert.py
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import poptorch
import torch
from transformers import HubertForCTC, HubertForSequenceClassification
from transformers.modeling_outputs import CausalLMOutput
from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm

from optimum.utils import logging

from ...modeling_utils import PipelineMixin, get_layer_ipu, recomputation_checkpoint, register
from .ipu_layer_drop import IPUHubertEncoder, IPUHubertEncoderStableLayerNorm


logger = logging.get_logger(__name__)


@register(HubertForSequenceClassification)
class PipelinedHubertForSequenceClassification(HubertForSequenceClassification, PipelineMixin):
def change_hubert_encoder_class(self, restore: bool):
"""Changes the encoder class to update its forward pass so that it uses our custom version.
Args:
restore: whether to restore the encoder to its original version or not.
"""
if self.config.do_stable_layer_norm:
new_cls = HubertEncoderStableLayerNorm if restore else IPUHubertEncoderStableLayerNorm
else:
new_cls = HubertEncoder if restore else IPUHubertEncoder
self.hubert.encoder.__class__ = new_cls

    def parallelize(self):
super().parallelize()
self.change_hubert_encoder_class(False)
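        # Pipeline stage boundaries: the feature extractor, the feature projection and the
        # start of the encoder are all placed on IPU 0.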
self.hubert.feature_extractor = poptorch.BeginBlock(self.hubert.feature_extractor, ipu_id=0)
self.hubert.feature_projection = poptorch.BeginBlock(self.hubert.feature_projection, ipu_id=0)
self.hubert.encoder = poptorch.BeginBlock(self.hubert.encoder, ipu_id=0)
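        # Spread the transformer encoder layers over the IPUs according to the IPU configuration.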
layer_ipu = get_layer_ipu(self.ipu_config, self.hubert.encoder.layers)
for index, layer in enumerate(self.hubert.encoder.layers):
# Put checkpoints on every encoder layer
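            # (recomputation trades extra compute for memory: activations are recomputed
            # during the backward pass instead of being stored)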
h = recomputation_checkpoint(layer)
self._hooks.append(h)
ipu = layer_ipu[index]
self.hubert.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
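        # The projector and classifier form the head and sit on the last IPU of the pipeline.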
last_ipu = self.ipu_config._ipus_per_replica - 1
self.projector = poptorch.BeginBlock(self.projector, ipu_id=last_ipu)
self.classifier = poptorch.BeginBlock(self.classifier, ipu_id=last_ipu)
return self

    def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
"""
super().deparallelize()
self.change_hubert_encoder_class(True)
return self


@register(HubertForCTC)
class PipelinedHubertForCTC(HubertForCTC, PipelineMixin):
def change_hubert_encoder_class(self, restore: bool):
"""Changes the encoder class to update its forward pass so that it uses our custom version.
Args:
restore: whether to restore the encoder to its original version or not.
"""
if self.config.do_stable_layer_norm:
new_cls = HubertEncoderStableLayerNorm if restore else IPUHubertEncoderStableLayerNorm
else:
new_cls = HubertEncoder if restore else IPUHubertEncoder
self.hubert.encoder.__class__ = new_cls

    def _add_begin_block(self, module, name, ipu_id):
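        # BeginBlock annotates `module` in place with the pipeline stage it starts,
        # so the return value is not needed here.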
poptorch.BeginBlock(module, name, ipu_id)

    def parallelize(self):
super().parallelize()
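        # Freeze the convolutional feature encoder, as is standard when fine-tuning HuBERT for CTC.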
self.freeze_feature_encoder()
self.change_hubert_encoder_class(False)
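        # When the replica spans several IPUs, split the model into named pipeline stages
        # and log which stage goes to which IPU.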
if self.ipu_config._ipus_per_replica != 1:
logger.info("---------- Device Allocation -----------")
layers = []
# Conv layers
for index, layer in enumerate(self.hubert.feature_extractor.conv_layers):
layers.append((f"Conv {index:<2}", layer))
# Positional Embedding
layers.append(("Positional Embedding", self.hubert.encoder.pos_conv_embed))
# Encoder layers
for index, layer in enumerate(self.hubert.encoder.layers):
self._hooks.append(recomputation_checkpoint(layer))
layers.append((f"Encoder {index:<2}", layer))
# Project Hidden
layers.append(("Project Hidden", self.lm_head))
layer_ipu = get_layer_ipu(self.ipu_config, layers)
for i, (name, layer) in enumerate(layers):
logger.info(f"{name} --> IPU {layer_ipu[i]}")
self._add_begin_block(layer, name, ipu_id=layer_ipu[i])
logger.info("---------------------------------------")
return self

    def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
"""
super().deparallelize()
self.change_hubert_encoder_class(True)
return self

    def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
config.vocab_size - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.hubert(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# assuming that padded tokens are filled with -100
# when not being attended to
labels_mask = labels >= 0
target_lengths = labels_mask.sum(-1)
# flattened_targets = labels.masked_select(labels_mask)
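            # nn.CTCLoss accepts padded 2D targets together with target_lengths, so the
            # dynamically shaped masked_select used upstream is not needed and the graph stays static.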
# ctc_loss doesn't support fp16
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1).transpose(0, 1)
loss_fn = torch.nn.CTCLoss(
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
loss = loss_fn(log_probs, labels, input_lengths, target_lengths)
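            # identity_loss marks this tensor as the loss PopTorch backpropagates; the CTC
            # reduction has already been applied, hence reduction="none" here.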
loss = poptorch.identity_loss(loss, "none")
if not return_dict:
if loss is not None:
return loss, logits
return (logits, hidden_states)
return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
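

# Illustrative sketch of typical usage, assuming optimum-graphcore helpers along the lines of
# `from_transformers` and `to_options` (hypothetical names here) are available:
#
#   model = PipelinedHubertForCTC.from_transformers(hf_model, ipu_config).parallelize().half()
#   opts = ipu_config.to_options(for_inference=True)
#   inference_model = poptorch.inferenceModel(model.eval(), opts)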