# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import poptorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import truncnorm
from transformers import (
BertConfig,
BertForMaskedLM,
BertForMultipleChoice,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
MaskedLMOutput,
QuestionAnsweringModelOutput,
)
from transformers.modeling_utils import apply_chunking_to_forward
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertModel
from optimum.utils import logging
from ...modeling_utils import (
OnehotGather,
PipelineMixin,
SerializedEmbedding,
SerializedLinear,
get_layer_ipu,
outline_attribute,
recomputation_checkpoint,
register,
)
from .groupbert_attention import GroupBertAttention
from .groupbert_convolution import GroupBertConvolution
from .groupbert_ffn import GroupBertIntermediate, GroupBertOutput
logger = logging.get_logger(__name__)
class GroupBertConfig(BertConfig):
r"""
This is the configuration class to store the configuration of a [`GroupBertModel`]. It is used to
instantiate a GroupBERT model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`BertConfig`] and can be used to control the model outputs. Read the
documentation from [`BertConfig`] for more information.
Args:
        ffn_groups (`int`, *optional*, defaults to 4):
            Number of groups in the down projection of the FFN.
        conv_group_size (`int`, *optional*, defaults to 16):
            Group size for the convolution operation in the dedicated convolution module in GroupBERT.
        conv_kernel_size (`int`, *optional*, defaults to 7):
            Kernel size for the convolution operation in the dedicated convolution module in GroupBERT.
Examples:
```python
>>> from optimum.graphcore import GroupBertModel, GroupBertConfig
>>> # Initializing a GroupBERT configuration
>>> configuration = GroupBertConfig()
>>> # Initializing a model
>>> model = GroupBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
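    >>> # Overriding the GroupBERT-specific arguments (illustrative values)
    >>> custom_configuration = GroupBertConfig(ffn_groups=8, conv_group_size=32, conv_kernel_size=5)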
```"""
model_type = "groupbert"
def __init__(self, ffn_groups=4, conv_group_size=16, conv_kernel_size=7, **kwargs):
super().__init__(**kwargs)
self.ffn_groups = ffn_groups
self.conv_group_size = conv_group_size
self.conv_kernel_size = conv_kernel_size
class GroupBertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
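        # A GroupBERT layer runs: grouped convolution -> first FFN -> self-attention
        # (plus optional cross-attention) -> second FFN, i.e. two FFN blocks per layer
        # with a dedicated convolution module in front (see `forward` below).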
self.convolution = GroupBertConvolution(config)
self.intermediate_first = GroupBertIntermediate(config)
self.output_first = GroupBertOutput(config)
self.attention = GroupBertAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = GroupBertAttention(config, position_embedding_type="absolute")
self.intermediate_second = GroupBertIntermediate(config)
self.output_second = GroupBertOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
convolution_output = self.convolution(hidden_states, attention_mask)
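        # `apply_chunking_to_forward` optionally splits the FFN over the sequence
        # dimension (dim 1) in chunks of `chunk_size_feed_forward` to save memory;
        # with the default chunk size of 0 it is a plain call.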
first_layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_first, self.chunk_size_feed_forward, self.seq_len_dim, convolution_output
)
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
first_layer_output,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_second, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk_first(self, conv_output):
intermediate_output = self.intermediate_first(conv_output)
layer_output = self.output_first(intermediate_output, conv_output)
return layer_output
def feed_forward_chunk_second(self, attention_output):
intermediate_output = self.intermediate_second(attention_output)
layer_output = self.output_second(intermediate_output, attention_output)
return layer_output
class GroupBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([GroupBertLayer(config) for _ in range(config.num_hidden_layers)])
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
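        # Unlike standard BERT, GroupBERT applies one extra encoder-level LayerNorm
        # to the hidden states after the last layer.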
hidden_states = self.LayerNorm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class GroupBertModel(BertModel):
config_class = GroupBertConfig
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.encoder = GroupBertEncoder(config)
# Initialize weights and apply final processing
self.post_init()
def _init_weights(self, module):
"""Initialize the weights"""
def truncated_normal_(tensor, mean=0, std=1):
"""
Truncated Normal distribution, truncated at 2 sigma
"""
r = torch.tensor(truncnorm.rvs(-2, 2, loc=mean, scale=std, size=tensor.shape))
tensor.data.copy_(r)
if isinstance(module, nn.Linear):
truncated_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
truncated_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class GroupBertForPreTraining(BertForPreTraining):
config_class = GroupBertConfig
def __init__(self, config):
super().__init__(config)
self.bert = GroupBertModel(config)
# Initialize weights and apply final processing
self.post_init()
class GroupBertForMaskedLM(BertForMaskedLM):
config_class = GroupBertConfig
def __init__(self, config):
super().__init__(config)
self.bert = GroupBertModel(config)
# Initialize weights and apply final processing
self.post_init()
class GroupBertForSequenceClassification(BertForSequenceClassification):
config_class = GroupBertConfig
def __init__(self, config):
super().__init__(config)
self.bert = GroupBertModel(config)
# Initialize weights and apply final processing
self.post_init()
class GroupBertForMultipleChoice(BertForMultipleChoice):
config_class = GroupBertConfig
def __init__(self, config):
super().__init__(config)
self.bert = GroupBertModel(config)
# Initialize weights and apply final processing
self.post_init()
class GroupBertForTokenClassification(BertForTokenClassification):
config_class = GroupBertConfig
def __init__(self, config):
super().__init__(config)
self.bert = GroupBertModel(config)
# Initialize weights and apply final processing
self.post_init()
class GroupBertForQuestionAnswering(BertForQuestionAnswering):
config_class = GroupBertConfig
def __init__(self, config):
super().__init__(config)
self.bert = GroupBertModel(config)
# Initialize weights and apply final processing
self.post_init()
@register(GroupBertForPreTraining)
class PipelinedGroupBertForPreTraining(GroupBertForPreTraining, PipelineMixin):
"""
GroupBertForPreTraining transformed to run in an IPU pipeline.
Recommended usage:
```
model = PipelinedGroupBertForPreTraining(config).parallelize().half().train()
```
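    A minimal compile sketch (one possible flow, not the only one; assumes
    `model.ipu_config` has already been attached and `opt` is a PopTorch optimizer):
    ```
    training_model = poptorch.trainingModel(model, options=model.ipu_config.to_options(), optimizer=opt)
    ```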
"""
def __init__(self, config):
super().__init__(config)
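        # `OnehotGather` selects hidden states at given positions via a one-hot matmul,
        # which is typically friendlier to the IPU than a dynamic gather.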
self.gather_indices = OnehotGather()
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- (If enabled) Replaces the word embedding projection with a SerializedLinear layer
- Adds recomputation checkpoints
"""
super().parallelize()
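        # `SerializedLinear` splits the vocab-sized decoder matmul into
        # `embedding_serialization_factor` sequential slices to lower peak memory;
        # `tie_weights` is called again because swapping the decoder module breaks
        # the weight tie with the input word embedding.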
if self.ipu_config.embedding_serialization_factor > 1:
self.cls.predictions.decoder = SerializedLinear.from_model(
self.cls.predictions.decoder, self.ipu_config.embedding_serialization_factor
)
self.tie_weights()
layer_ipu = get_layer_ipu(self.ipu_config, self.bert.encoder.layer)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0)
# Preventing the embeddings.LayerNorm from being outlined with the encoder.layer.LayerNorm
# improves the tile mapping of the pipeline stashes
hs = outline_attribute(self.bert.embeddings.LayerNorm, "embeddings")
self._hooks.extend(hs)
for index, layer in enumerate(self.bert.encoder.layer):
ipu = layer_ipu[index]
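            # Checkpoint each layer's output so its internal activations are
            # recomputed during the backward pass instead of being stored.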
if self.ipu_config.recompute_checkpoint_every_layer:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
logger.info("Pooler --> IPU 0")
self.bert.pooler = poptorch.BeginBlock(self.bert.pooler, "Pooler", ipu_id=0)
logger.info("Classifier --> IPU 0")
self.cls = poptorch.BeginBlock(self.cls, "Classifier", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
        You should call this before calling `save_pretrained` so that the `model.state_dict` is
        compatible with the original model.
"""
super().deparallelize()
if isinstance(self.cls.predictions.decoder, SerializedLinear):
self.cls.predictions.decoder = self.cls.predictions.decoder.to_model()
self.tie_weights()
return self
def _init_weights(self, module):
"""Initialize the weights"""
def truncated_normal_(tensor, mean=0, std=1):
"""
Truncated Normal distribution, truncated at 2 sigma
"""
r = torch.tensor(truncnorm.rvs(-2, 2, loc=mean, scale=std, size=tensor.shape))
tensor.data.copy_(r)
if isinstance(module, nn.Linear):
truncated_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
truncated_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
next_sentence_label: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output, pooled_output = outputs[:2]
if labels is not None:
if hasattr(self.config, "max_num_masked_tokens"):
# Select only the masked tokens for the classifier
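                # Unmasked positions carry the label -100 (the usual Hugging Face
                # convention), so `topk` returns the positions of the (up to)
                # `max_num_masked_tokens` masked tokens; surplus slots keep -100
                # and are ignored by `cross_entropy` below.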
labels, positions = torch.topk(labels, k=self.config.max_num_masked_tokens, dim=1)
sequence_output = self.gather_indices(sequence_output, positions)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
masked_lm_loss = F.cross_entropy(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1),
).float()
next_sentence_loss = F.cross_entropy(
seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)
).float()
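            # `poptorch.identity_loss` marks this tensor as the graph's final loss;
            # `reduction="none"` leaves the already-reduced scalar unchanged.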
total_loss = poptorch.identity_loss(masked_lm_loss + next_sentence_loss, reduction="none")
        # If labels are provided (training mode), only output the loss
if not return_dict:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return (total_loss,) if total_loss is not None else output
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores if total_loss is None else None,
seq_relationship_logits=seq_relationship_score if total_loss is None else None,
hidden_states=outputs.hidden_states if total_loss is None else None,
attentions=outputs.attentions if total_loss is None else None,
)
@register(GroupBertForMaskedLM)
class PipelinedGroupBertForMaskedLM(GroupBertForMaskedLM, PipelineMixin):
"""
GroupBertForMaskedLM transformed to run in an IPU pipeline.
Recommended usage:
```
model = PipelinedGroupBertForMaskedLM(config).parallelize().half().train()
```
"""
def __init__(self, config):
super().__init__(config)
self.gather_indices = OnehotGather()
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- (If enabled) Replaces the word embedding projection with a SerializedLinear layer
- Adds recomputation checkpoints
"""
super().parallelize()
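        # Serialize the vocab-sized decoder matmul along its output channels into
        # `embedding_serialization_factor` slices to lower peak memory; re-tying the
        # weights restores the tie with the input word embedding.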
if self.ipu_config.embedding_serialization_factor > 1:
serialized_decoder = SerializedLinear(
self.config.hidden_size,
self.config.vocab_size,
self.ipu_config.embedding_serialization_factor,
bias=True,
mode=poptorch.MatMulSerializationMode.OutputChannels,
)
serialized_decoder.load_state_dict(self.cls.predictions.decoder.state_dict())
self.cls.predictions.decoder = serialized_decoder
self.tie_weights()
layer_ipu = get_layer_ipu(self.ipu_config, self.bert.encoder.layer)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0)
# Preventing the embeddings.LayerNorm from being outlined with the encoder.layer.LayerNorm
# improves the tile mapping of the pipeline stashes
hs = outline_attribute(self.bert.embeddings.LayerNorm, "embeddings")
self._hooks.extend(hs)
for index, layer in enumerate(self.bert.encoder.layer):
ipu = layer_ipu[index]
if self.ipu_config.recompute_checkpoint_every_layer:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
logger.info("Classifier --> IPU 0")
self.cls = poptorch.BeginBlock(self.cls, "Classifier", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
        You should call this before calling `save_pretrained` so that the `model.state_dict` is
        compatible with the original model.
"""
super().deparallelize()
if self.ipu_config.embedding_serialization_factor > 1:
decoder = nn.Linear(
self.config.hidden_size,
self.config.vocab_size,
bias=True,
)
decoder.load_state_dict(self.cls.predictions.decoder.state_dict())
self.cls.predictions.decoder = decoder
self.tie_weights()
return self
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training:
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
if hasattr(self.config, "max_num_masked_tokens"):
# Select only the masked tokens for the classifier
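                # As in pretraining: unmasked positions are labelled -100, so `topk`
                # picks the masked positions and surplus -100 labels are ignored by
                # the loss.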
labels, positions = torch.topk(labels, k=self.config.max_num_masked_tokens, dim=1)
sequence_output = self.gather_indices(sequence_output, positions)
prediction_scores = self.cls(sequence_output)
outputs = (prediction_scores,) + outputs[2:]
masked_lm_loss = F.cross_entropy(
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
).float()
# When training only return the loss
if return_dict:
return MaskedLMOutput(loss=masked_lm_loss)
else:
return (masked_lm_loss,)
else:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
labels=labels,
return_dict=return_dict,
)
class BertPipelineMixin(PipelineMixin):
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- (If enabled) Replaces the word embedding with a SerializedEmbedding
- Adds recomputation checkpoints
"""
super().parallelize()
layer_ipu = get_layer_ipu(self.ipu_config, self.bert.encoder.layer)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
if self.ipu_config.embedding_serialization_factor > 1:
self.bert.embeddings.word_embeddings = SerializedEmbedding.from_model(
self.bert.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor
)
self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0)
hs = outline_attribute(self.bert.embeddings.LayerNorm, "embedding")
self._hooks.extend(hs)
for index, layer in enumerate(self.bert.encoder.layer):
ipu = layer_ipu[index]
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
return self
def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
        You should call this before calling `save_pretrained` so that the `model.state_dict` is
        compatible with the original model.
"""
super().deparallelize()
# Deserialize the serialized word embedding
if self.ipu_config.embedding_serialization_factor > 1:
self.bert.embeddings.word_embeddings = self.bert.embeddings.word_embeddings.to_model()
return self
@register(GroupBertForSequenceClassification)
class PipelinedGroupBertForSequenceClassification(GroupBertForSequenceClassification, BertPipelineMixin):
"""
GroupBertForSequenceClassification transformed to run in an IPU pipeline.
Recommended usage:
```
model = PipelinedGroupBertForSequenceClassification(config).parallelize().half()
```
"""
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Classifier Output --> IPU {last_ipu}")
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self
@register(GroupBertForMultipleChoice)
class PipelinedGroupBertForMultipleChoice(GroupBertForMultipleChoice, BertPipelineMixin):
"""
GroupBertForMultipleChoice transformed to run in an IPU pipeline.
Recommended usage:
```
model = PipelinedGroupBertForMultipleChoice(config).parallelize().half()
```
"""
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Classifier Output --> IPU {last_ipu}")
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self
@register(GroupBertForTokenClassification)
class PipelinedGroupBertForTokenClassification(GroupBertForTokenClassification, BertPipelineMixin):
"""
GroupBertForTokenClassification transformed to run in an IPU pipeline.
Recommended usage:
```
model = PipelinedGroupBertForTokenClassification(config).parallelize().half()
```
"""
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Classifier Output --> IPU {last_ipu}")
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self
@register(GroupBertForQuestionAnswering)
class PipelinedGroupBertForQuestionAnswering(GroupBertForQuestionAnswering, BertPipelineMixin):
"""
GroupBertForQuestionAnswering transformed to run in an IPU pipeline.
Recommended usage:
```
model = PipelinedGroupBertForQuestionAnswering(config).parallelize().half()
```
"""
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"QA Outputs --> IPU {last_ipu}")
self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the start of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the end of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output = super().forward(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
start_positions=start_positions,
end_positions=end_positions,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if start_positions is not None and end_positions is not None:
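            # Mark the combined start/end loss (the mean computed by the parent class)
            # as the graph's final loss for PopTorch.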
output = (poptorch.identity_loss(output[0], reduction="none"),) + output[1:]
return output