# optimum/graphcore/models/distilbert/modeling_distilbert.py
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple, Union

import poptorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
)
from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput
from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention

from optimum.utils import logging

from ...modeling_utils import (
OnehotGather,
PipelineMixin,
SerializedEmbedding,
SerializedLinear,
get_layer_ipu,
recomputation_checkpoint,
register,
)

logger = logging.get_logger(__name__)


class IPUMultiHeadSelfAttention(MultiHeadSelfAttention):
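    """IPU-friendly drop-in for `MultiHeadSelfAttention`.

    Behaviourally equivalent to the upstream layer, but the attention mask is
    applied as a purely arithmetic additive bias instead of `masked_fill`; the
    motivation is an assumption here, but additive masks generally map better
    onto the IPU's statically compiled graphs.
    """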
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
"""
Parameters:
query: torch.tensor(bs, seq_length, dim)
key: torch.tensor(bs, seq_length, dim)
value: torch.tensor(bs, seq_length, dim)
mask: torch.tensor(bs, seq_length)
Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
"""
bs, q_length, dim = query.size()
k_length = key.size(1)
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
# assert key.size() == value.size()
dim_per_head = self.dim // self.n_heads
mask_reshp = (bs, 1, 1, k_length)

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = mask.to(dtype=scores.dtype) # fp16 compatibility
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
mask = (1.0 - mask) * -10000.0
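        # e.g. a key mask row of [1., 1., 0.] becomes [0., 0., -10000.], so the
        # masked position gets ~0 weight after the softmax below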
mask = mask.view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores + mask # (bs, n_heads, q_length, k_length)
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim)
if output_attentions:
return (context, weights)
else:
return (context,)


class DistilBertPipelineMixin(PipelineMixin):
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- Adds recomputation checkpoints
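
        A hedged usage sketch (illustrative only; in practice `IPUTrainer` wires
        up `ipu_config` and the PopTorch options, and `opts` below is a
        hypothetical `poptorch.Options` instance):

            model = PipelinedDistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
            model.ipu_config = ipu_config  # an optimum-graphcore IPUConfig
            model.parallelize()
            inference_model = poptorch.inferenceModel(model.eval(), opts)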
"""
super().parallelize()
for layer in self.distilbert.transformer.layer:
layer.attention.__class__ = IPUMultiHeadSelfAttention
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
is_masked_lm = isinstance(self, DistilBertForMaskedLM)
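        # For masked LM the word embedding stays unserialized: it is tied to
        # `vocab_projector`, which PipelinedDistilBertForMaskedLM.parallelize
        # serializes as a SerializedLinear instead.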
if self.ipu_config.embedding_serialization_factor > 1 and not is_masked_lm:
self.distilbert.embeddings.word_embeddings = SerializedEmbedding.from_model(
self.distilbert.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor
)
self.distilbert.embeddings = poptorch.BeginBlock(self.distilbert.embeddings, "Embedding", 0)
layer_ipu = get_layer_ipu(self.ipu_config, self.distilbert.transformer.layer)
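        # layer_ipu[i] is the IPU id that should host encoder layer i, derived
        # from the ipu_config placement (e.g. layers_per_ipu)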
for index, layer in enumerate(self.distilbert.transformer.layer):
ipu = layer_ipu[index]
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.distilbert.transformer.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
return self

    def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
compatible with the original model.
"""
super().deparallelize()
for layer in self.distilbert.transformer.layer:
layer.attention.__class__ = MultiHeadSelfAttention
is_masked_lm = isinstance(self, DistilBertForMaskedLM)
if self.ipu_config.embedding_serialization_factor > 1 and not is_masked_lm:
self.distilbert.embeddings.word_embeddings = self.distilbert.embeddings.word_embeddings.to_model()
return self


@register(DistilBertForMaskedLM)
class PipelinedDistilBertForMaskedLM(DistilBertForMaskedLM, DistilBertPipelineMixin):
def __init__(self, config):
super().__init__(config)
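        # OnehotGather selects hidden states at given positions via a one-hot
        # matmul; the motivation is an assumption here, but this avoids dynamic
        # gather ops that the IPU's static graphs handle poorly.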
self.gather_indices = OnehotGather()

    def parallelize(self):
super().parallelize()
if self.ipu_config.embedding_serialization_factor > 1:
self.vocab_projector = SerializedLinear.from_model(
self.vocab_projector, self.ipu_config.embedding_serialization_factor
)
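            # Re-tie the now-serialized projector to the word embeddings so the
            # two keep sharing weights.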
self.tie_weights()
logger.info("LM Head --> IPU 0")
self.vocab_transform = poptorch.BeginBlock(self.vocab_transform, "LM Head", ipu_id=0)
self.vocab_layer_norm = poptorch.BeginBlock(self.vocab_layer_norm, "LM Head", ipu_id=0)
self.vocab_projector = poptorch.BeginBlock(self.vocab_projector, "LM Head", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self

    def deparallelize(self):
        super().deparallelize()
        if isinstance(self.vocab_projector, SerializedLinear):
            self.vocab_projector = self.vocab_projector.to_model()
            self.tie_weights()
        return self

def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training:
dlbrt_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
if hasattr(self.config, "max_num_masked_tokens"):
# Select only the masked tokens for the classifier
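            # topk works because non-masked positions carry the ignore label -100,
            # so the k largest labels are exactly the masked positions (padded with
            # -100 entries, which F.cross_entropy ignores by default)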
labels, positions = torch.topk(labels, k=self.config.max_num_masked_tokens, dim=1)
hidden_states = self.gather_indices(hidden_states, positions)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = self.activation(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
masked_lm_loss = F.cross_entropy(prediction_logits.view(-1, self.config.vocab_size), labels.view(-1))
            # When training, only return the loss.
if return_dict:
return MaskedLMOutput(loss=masked_lm_loss)
else:
return (masked_lm_loss,)
else:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)


@register(DistilBertForSequenceClassification)
class PipelinedDistilBertForSequenceClassification(DistilBertForSequenceClassification, DistilBertPipelineMixin):
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Classifier --> IPU {last_ipu}")
self.pre_classifier = poptorch.BeginBlock(self.pre_classifier, "Classifier", ipu_id=last_ipu)
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self


@register(DistilBertForQuestionAnswering)
class PipelinedDistilBertForQuestionAnswering(DistilBertForQuestionAnswering, DistilBertPipelineMixin):
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"QA Outputs --> IPU {last_ipu}")
self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self

    def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor, ...]]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
start_positions=start_positions,
end_positions=end_positions,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if start_positions is not None and end_positions is not None:
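            # PopTorch needs the graph's final loss marked explicitly;
            # identity_loss does that for the already-computed QA loss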
output = (poptorch.identity_loss(output[0], reduction="none"),) + output[1:]
return output


@register(DistilBertForTokenClassification)
class PipelinedDistilBertForTokenClassification(DistilBertForTokenClassification, DistilBertPipelineMixin):
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Classifier --> IPU {last_ipu}")
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self


@register(DistilBertForMultipleChoice)
class PipelinedDistilBertForMultipleChoice(DistilBertForMultipleChoice, DistilBertPipelineMixin):
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Classifier --> IPU {last_ipu}")
self.pre_classifier = poptorch.BeginBlock(self.pre_classifier, "Classifier", ipu_id=last_ipu)
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self