optimum/graphcore/models/bert/modeling_bert.py

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import poptorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import truncnorm
from transformers import (
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForPreTraining,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
)
from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertModel, BertSelfAttention

from optimum.utils import logging

from ...modeling_utils import (
    OnehotGather,
    PipelineMixin,
    SerializedEmbedding,
    SerializedLinear,
    get_layer_ipu,
    outline_attribute,
    recomputation_checkpoint,
    register,
)
from .bert_fused_attention import BertFusedSelfAttention


logger = logging.get_logger(__name__)


@register(BertModel)
class PipelinedBertModel(BertModel, PipelineMixin):
    def __init__(self, config):
        super().__init__(config)

    def parallelize(self):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - Replaces self-attention layers with fused-qkv self-attention layers
        - (If enabled) Replaces the word embedding with a SerializedEmbedding
        - Adds recomputation checkpoints
        """
        super().parallelize()

        # Use faster fused-qkv self-attention
        for layer in self.encoder.layer:
            layer.attention.self.__class__ = BertFusedSelfAttention

        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding --> IPU 0")
        if self.ipu_config.embedding_serialization_factor > 1:
            self.embeddings.word_embeddings = SerializedEmbedding.from_model(
                self.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor
            )
        self.embeddings = poptorch.BeginBlock(self.embeddings, "Embedding", ipu_id=0)
        hs = outline_attribute(self.embeddings.LayerNorm, "embedding")
        self._hooks.extend(hs)

        layer_ipu = get_layer_ipu(self.ipu_config, self.encoder.layer)
        for index, layer in enumerate(self.encoder.layer):
            ipu = layer_ipu[index]
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                h = recomputation_checkpoint(layer)
                self._hooks.append(h)
            self.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        logger.info("Pooler --> IPU 0")
        self.pooler = poptorch.BeginBlock(self.pooler, "Pooler", ipu_id=0)

        return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        compatible with the original model.
        """
        super().deparallelize()

        for layer in self.encoder.layer:
            layer.attention.self.__class__ = BertSelfAttention

        # Deserialize the serialized word embedding
        if self.ipu_config.embedding_serialization_factor > 1:
            self.embeddings.word_embeddings = self.embeddings.word_embeddings.to_model()

        return self


@register(BertForPreTraining)
class PipelinedBertForPreTraining(BertForPreTraining, PipelineMixin):
    """
    BertForPreTraining transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedBertForPreTraining(config).parallelize().half().train()
    ```
    """

    def __init__(self, config):
        super().__init__(config)
        self.gather_indices = OnehotGather()

    def parallelize(self):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - Replaces self-attention layers with fused-qkv self-attention layers
        - (If enabled) Replaces the word embedding projection with a SerializedLinear layer
        - Adds recomputation checkpoints
        """
        super().parallelize()

        # Use faster fused-qkv self-attention
        for layer in self.bert.encoder.layer:
            layer.attention.self.__class__ = BertFusedSelfAttention

        if self.ipu_config.embedding_serialization_factor > 1:
            self.cls.predictions.decoder = SerializedLinear.from_model(
                self.cls.predictions.decoder, self.ipu_config.embedding_serialization_factor
            )
            self.tie_weights()

        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding --> IPU 0")
        self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0)

        # Preventing the embeddings.LayerNorm from being outlined with the encoder.layer.LayerNorm
        # improves the tile mapping of the pipeline stashes
        hs = outline_attribute(self.bert.embeddings.LayerNorm, "embeddings")
        self._hooks.extend(hs)

        layer_ipu = get_layer_ipu(self.ipu_config, self.bert.encoder.layer)
        for index, layer in enumerate(self.bert.encoder.layer):
            ipu = layer_ipu[index]
            if self.ipu_config.recompute_checkpoint_every_layer:
                h = recomputation_checkpoint(layer)
                self._hooks.append(h)
            self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        logger.info("Pooler --> IPU 0")
        self.bert.pooler = poptorch.BeginBlock(self.bert.pooler, "Pooler", ipu_id=0)

        logger.info("Classifier --> IPU 0")
        self.cls = poptorch.BeginBlock(self.cls, "Classifier", ipu_id=0)
        logger.info("-----------------------------------------------------------")
        return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        compatible with the original model.
        """
        super().deparallelize()

        for layer in self.bert.encoder.layer:
            layer.attention.self.__class__ = BertSelfAttention

        if isinstance(self.cls.predictions.decoder, SerializedLinear):
            self.cls.predictions.decoder = self.cls.predictions.decoder.to_model()
            self.tie_weights()

        return self

    def _init_weights(self, module):
        """Initialize the weights"""

        def truncated_normal_(tensor, mean=0, std=1):
            """
            Truncated Normal distribution, truncated at 2 sigma
            """
            r = torch.tensor(truncnorm.rvs(-2, 2, loc=mean, scale=std, size=tensor.shape))
            tensor.data.copy_(r)

        if isinstance(module, nn.Linear):
            truncated_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            truncated_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        if labels is not None:
            if hasattr(self.config, "max_num_masked_tokens"):
                # Select only the masked tokens for the classifier
                labels, positions = torch.topk(labels, k=self.config.max_num_masked_tokens, dim=1)
                sequence_output = self.gather_indices(sequence_output, positions)

        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            masked_lm_loss = F.cross_entropy(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            ).float()
            next_sentence_loss = F.cross_entropy(
                seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)
            ).float()
            total_loss = poptorch.identity_loss(masked_lm_loss + next_sentence_loss, reduction="none")

        # If labels are provided (training mode) only output the loss
        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return (total_loss,) if total_loss is not None else output

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores if total_loss is None else None,
            seq_relationship_logits=seq_relationship_score if total_loss is None else None,
            hidden_states=outputs.hidden_states if total_loss is None else None,
            attentions=outputs.attentions if total_loss is None else None,
        )


@register(BertForMaskedLM)
class PipelinedBertForMaskedLM(BertForMaskedLM, PipelineMixin):
    """
    BertForMaskedLM transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedBertForMaskedLM(config).parallelize().half().train()
    ```
    """

    def __init__(self, config):
        super().__init__(config)
        self.gather_indices = OnehotGather()

    def parallelize(self):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - Replaces self-attention layers with fused-qkv self-attention layers
        - (If enabled) Replaces the word embedding projection with a SerializedLinear layer
        - Adds recomputation checkpoints
        """
        super().parallelize()

        # Use faster fused-qkv self-attention
        for layer in self.bert.encoder.layer:
            layer.attention.self.__class__ = BertFusedSelfAttention

        if self.ipu_config.embedding_serialization_factor > 1:
            self.cls.predictions.decoder = SerializedLinear.from_model(
                self.cls.predictions.decoder, self.ipu_config.embedding_serialization_factor
            )
            self.tie_weights()

        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding --> IPU 0")
        self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0)

        # Preventing the embeddings.LayerNorm from being outlined with the encoder.layer.LayerNorm
        # improves the tile mapping of the pipeline stashes
        hs = outline_attribute(self.bert.embeddings.LayerNorm, "embeddings")
        self._hooks.extend(hs)

        layer_ipu = get_layer_ipu(self.ipu_config, self.bert.encoder.layer)
        for index, layer in enumerate(self.bert.encoder.layer):
            ipu = layer_ipu[index]
            if self.ipu_config.recompute_checkpoint_every_layer:
                h = recomputation_checkpoint(layer)
                self._hooks.append(h)
            self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        logger.info("Classifier --> IPU 0")
        self.cls = poptorch.BeginBlock(self.cls, "Classifier", ipu_id=0)
        logger.info("-----------------------------------------------------------")
        return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        compatible with the original model.
        """
        super().deparallelize()

        for layer in self.bert.encoder.layer:
            layer.attention.self.__class__ = BertSelfAttention

        if isinstance(self.cls.predictions.decoder, SerializedLinear):
            self.cls.predictions.decoder = self.cls.predictions.decoder.to_model()
            self.tie_weights()

        return self

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training:
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output = outputs[0]

            if hasattr(self.config, "max_num_masked_tokens"):
                # Select only the masked tokens for the classifier
                labels, positions = torch.topk(labels, k=self.config.max_num_masked_tokens, dim=1)
                sequence_output = self.gather_indices(sequence_output, positions)

            prediction_scores = self.cls(sequence_output)
            outputs = (prediction_scores,) + outputs[2:]

            masked_lm_loss = F.cross_entropy(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            ).float()
            # When training only return the loss
            if return_dict:
                return MaskedLMOutput(loss=masked_lm_loss)
            else:
                return (masked_lm_loss,)
        else:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                labels=labels,
                return_dict=return_dict,
            )


class BertPipelineMixin(PipelineMixin):
    def parallelize(self):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - Replaces self-attention layers with fused-qkv self-attention layers
        - (If enabled) Replaces the word embedding with a SerializedEmbedding
        - Adds recomputation checkpoints
        """
        super().parallelize()

        # Use faster fused-qkv self-attention
        for layer in self.bert.encoder.layer:
            layer.attention.self.__class__ = BertFusedSelfAttention

        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding --> IPU 0")
        if self.ipu_config.embedding_serialization_factor > 1:
            self.bert.embeddings.word_embeddings = SerializedEmbedding.from_model(
                self.bert.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor
            )
        self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0)
        hs = outline_attribute(self.bert.embeddings.LayerNorm, "embedding")
        self._hooks.extend(hs)

        layer_ipu = get_layer_ipu(self.ipu_config, self.bert.encoder.layer)
        for index, layer in enumerate(self.bert.encoder.layer):
            ipu = layer_ipu[index]
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                h = recomputation_checkpoint(layer)
                self._hooks.append(h)
            self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        compatible with the original model.
        """
        super().deparallelize()

        for layer in self.bert.encoder.layer:
            layer.attention.self.__class__ = BertSelfAttention

        # Deserialize the serialized word embedding
        if self.ipu_config.embedding_serialization_factor > 1:
            self.bert.embeddings.word_embeddings = self.bert.embeddings.word_embeddings.to_model()

        return self


@register(BertForSequenceClassification)
class PipelinedBertForSequenceClassification(BertForSequenceClassification, BertPipelineMixin):
    """
    BertForSequenceClassification transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedBertForSequenceClassification(config).parallelize().half()
    ```
    """

    def parallelize(self):
        super().parallelize()
        last_ipu = self.ipu_config._ipus_per_replica - 1
        logger.info(f"Classifier Output --> IPU {last_ipu}")
        self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
        logger.info("-----------------------------------------------------------")
        return self


@register(BertForMultipleChoice)
class PipelinedBertForMultipleChoice(BertForMultipleChoice, BertPipelineMixin):
    """
    BertForMultipleChoice transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedBertForMultipleChoice(config).parallelize().half()
    ```
    """

    def parallelize(self):
        super().parallelize()
        last_ipu = self.ipu_config._ipus_per_replica - 1
        logger.info(f"Classifier Output --> IPU {last_ipu}")
        self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
        logger.info("-----------------------------------------------------------")
        return self


@register(BertForTokenClassification)
class PipelinedBertForTokenClassification(BertForTokenClassification, BertPipelineMixin):
    """
    BertForTokenClassification transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedBertForTokenClassification(config).parallelize().half()
    ```
    """

    def parallelize(self):
        super().parallelize()
        last_ipu = self.ipu_config._ipus_per_replica - 1
        logger.info(f"Classifier Output --> IPU {last_ipu}")
        self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
        logger.info("-----------------------------------------------------------")
        return self


@register(BertForQuestionAnswering)
class PipelinedBertForQuestionAnswering(BertForQuestionAnswering, BertPipelineMixin):
    """
    BertForQuestionAnswering transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedBertForQuestionAnswering(config).parallelize().half()
    ```
    """

    def parallelize(self):
        super().parallelize()
        last_ipu = self.ipu_config._ipus_per_replica - 1
        logger.info(f"QA Outputs --> IPU {last_ipu}")
        self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu)
        logger.info("-----------------------------------------------------------")
        return self

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            start_positions=start_positions,
            end_positions=end_positions,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if start_positions is not None and end_positions is not None:
            output = (poptorch.identity_loss(output[0], reduction="none"),) + output[1:]
        return output
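

# ---------------------------------------------------------------------------
# Illustrative usage sketch (editor's addition, not part of the original
# module). It expands the "Recommended usage" snippets from the class
# docstrings above into the full parallelize()/deparallelize() lifecycle.
# The IPUConfig values, the way `ipu_config` is attached to the model, and the
# checkpoint directory are assumptions made for the example; in practice the
# optimum-graphcore IPUTrainer drives these steps.
if __name__ == "__main__":
    from transformers import BertConfig

    from optimum.graphcore import IPUConfig

    # Small 4-layer BERT so that one encoder layer can sit on each of 4 IPUs.
    config = BertConfig(num_hidden_layers=4)
    model = PipelinedBertForSequenceClassification(config)

    # Assumption: parallelize() reads `self.ipu_config`, so a config is
    # attached manually here (the IPUTrainer normally does this).
    model.ipu_config = IPUConfig(layers_per_ipu=[1, 1, 1, 1], recompute_checkpoint_every_layer=True)
    model = model.parallelize().half().train()

    # The parallelized model would normally be wrapped with
    # poptorch.trainingModel (together with poptorch.Options and an optimizer)
    # and executed on IPU hardware at this point.

    # Undo the IPU-specific changes before saving so that the state_dict
    # matches the original transformers BertForSequenceClassification.
    model = model.deparallelize()
    model.save_pretrained("bert-sequence-classification-checkpoint")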