# optimum/graphcore/models/gpt2/modeling_gpt2.py
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple, Union

import poptorch
import torch
import torch.nn as nn
from transformers import GPT2ForSequenceClassification, GPT2ForTokenClassification, GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

from optimum.utils import logging

from ...generation import IPUGenerationMixin
from ...modeling_utils import (
    PipelineMixin,
    SerializedEmbedding,
    SerializedLinear,
    get_layer_ipu,
    outline_attribute,
    recomputation_checkpoint,
    register,
)
from .optimized_gpt2_attn import OptimizedGPT2Attention


logger = logging.get_logger(__name__)


class GPT2PipelineMixin(PipelineMixin):
def parallelize(self):
"""
Transform the GPT2 model body to run in an IPU pipeline.
- Adds pipeline stages to the model
- (If enabled) Replaces the word embedding with a SerializedEmbedding
- Adds recomputation checkpoints
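
        A minimal usage sketch (hedged; it assumes `self.ipu_config` has already been attached to the
        model, e.g. by the optimum-graphcore loading machinery, and that a concrete subclass such as
        `PipelinedGPT2ForSequenceClassification` is used rather than this mixin directly):

        ```
        model = PipelinedGPT2ForSequenceClassification.from_pretrained("gpt2")
        model.parallelize().half()
        ```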
"""
super().parallelize()
# Use optimized attention
for layer in self.transformer.h:
layer.attn.__class__ = OptimizedGPT2Attention
if self.ipu_config.embedding_serialization_factor > 1:
# Resize token embedding using padding if vocab_size is not a multiple of embedding_serialization_factor.
self.actual_vocab_size = self.config.vocab_size
new_vocab_size = (
math.ceil(self.config.vocab_size / self.ipu_config.embedding_serialization_factor)
* self.ipu_config.embedding_serialization_factor
)
if new_vocab_size > self.actual_vocab_size:
self.resize_token_embeddings(new_vocab_size)
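            # Split the word-embedding look-up into `embedding_serialization_factor` serialized chunks to lower peak memory.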
self.transformer.wte = SerializedEmbedding.from_model(
self.transformer.wte, self.ipu_config.embedding_serialization_factor
)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
self.transformer.wte = poptorch.BeginBlock(self.transformer.wte, "Token embedding", ipu_id=0)
self.transformer.wpe = poptorch.BeginBlock(self.transformer.wpe, "Position embedding", ipu_id=0)
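        # Outline the final LayerNorm so its operations are grouped into a single reusable subgraph, which can reduce code size on the IPU.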
hs = outline_attribute(self.transformer.ln_f, "LayerNorm")
self._hooks.extend(hs)
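        # Assign each transformer block to an IPU according to `ipu_config.layers_per_ipu`.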
layer_ipu = get_layer_ipu(self.ipu_config, self.transformer.h)
for index, layer in enumerate(self.transformer.h):
ipu = layer_ipu[index]
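            # Mark each block's output as a recomputation checkpoint: the output is kept live and activations between checkpoints are recomputed in the backward pass.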
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
self.transformer.h[index] = poptorch.BeginBlock(layer, f"Layer{index}", ipu_id=ipu)
logger.info(f"Layer {index:<2} --> IPU {ipu}")
return self

    def deparallelize(self):
"""
Undo the changes to the model done by `parallelize`.
You should call this before doing `save_pretrained` so that the `model.state_dict` is
fully compatible with `transformers` models.
"""
super().deparallelize()
if self.ipu_config.embedding_serialization_factor > 1:
# Deserialize the serialized word embedding
self.transformer.wte = self.transformer.wte.to_model()
            # Resize token embeddings back to the original vocab_size
if self.config.vocab_size > self.actual_vocab_size:
self.resize_token_embeddings(self.actual_vocab_size)
# Switch back to non-optimized attention
for layer in self.transformer.h:
layer.attn.__class__ = GPT2Attention
return self


@register(GPT2LMHeadModel)
class PipelinedGPT2LMHeadModel(GPT2LMHeadModel, PipelineMixin, IPUGenerationMixin):
def parallelize(self, for_generation=False):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- Adds recomputation checkpoints
Recommended usage:
```
model = PipelinedGPT2LMHeadModel(config).parallelize().half()
```
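
        A fuller sketch (hedged; it assumes `self.ipu_config` is already populated, `opts` is a
        `poptorch.Options` instance configured for training, and `optimizer` is a poptorch-compatible
        optimizer):

        ```
        model = PipelinedGPT2LMHeadModel.from_pretrained("gpt2")
        model.parallelize().half()
        training_model = poptorch.trainingModel(model, options=opts, optimizer=optimizer)
        ```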
"""
PipelineMixin.parallelize(self)
# Use optimized attention
for layer in self.transformer.h:
layer.attn.__class__ = OptimizedGPT2Attention
if self.ipu_config.embedding_serialization_factor > 1:
# Resize token embedding using padding if vocab_size is not a multiple of embedding_serialization_factor.
self.actual_vocab_size = self.config.vocab_size
new_vocab_size = (
math.ceil(self.config.vocab_size / self.ipu_config.embedding_serialization_factor)
* self.ipu_config.embedding_serialization_factor
)
if new_vocab_size > self.actual_vocab_size:
# There is a tie_weights operation in resize_token_embeddings so the lm_head's weight is also resized.
self.resize_token_embeddings(new_vocab_size)
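            # Serialize the LM-head matmul into `embedding_serialization_factor` splits and re-tie its weight to the word embedding.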
self.lm_head = SerializedLinear.from_model(self.lm_head, self.ipu_config.embedding_serialization_factor)
self.tie_weights()
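        # For generation, only the logits of selected positions are needed, so (when `for_generation=True`) the LM head is
        # swapped for a variant that gathers those positions before the projection; `restore=not for_generation` undoes it otherwise.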
self.change_lm_head_to_indexed_input_linear(restore=not for_generation)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Token Embedding --> IPU 0")
self.transformer.wte = poptorch.BeginBlock(self.transformer.wte, "Token embedding", ipu_id=0)
logger.info("Position Embedding --> IPU 0")
self.transformer.wpe = poptorch.BeginBlock(self.transformer.wpe, "Position embedding", ipu_id=0)
hs = outline_attribute(self.transformer.ln_f, "LayerNorm")
self._hooks.extend(hs)
layer_ipu = get_layer_ipu(self.ipu_config, self.transformer.h)
for index, layer in enumerate(self.transformer.h):
ipu = layer_ipu[index]
if self.ipu_config.recompute_checkpoint_every_layer:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
self.transformer.h[index] = poptorch.BeginBlock(layer, f"Layer{index}", ipu_id=ipu)
logger.info(f"Layer {index:<2} --> IPU {ipu}")
logger.info("Head --> IPU 0")
self.lm_head = poptorch.BeginBlock(self.lm_head, "LM head", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        fully compatible with `transformers` models.
        """
        PipelineMixin.deparallelize(self)
self.change_lm_head_to_indexed_input_linear(restore=True)
if isinstance(self.lm_head, SerializedLinear):
self.lm_head = self.lm_head.to_model()
self.tie_weights()
        # Resize token embeddings back to the original vocab_size.
        # There is a tie_weights operation in resize_token_embeddings so the lm_head's weight is also resized.
        if self.ipu_config.embedding_serialization_factor > 1 and self.config.vocab_size > self.actual_vocab_size:
self.resize_token_embeddings(self.actual_vocab_size)
# Switch back to non-optimized attention
for layer in self.transformer.h:
layer.attn.__class__ = GPT2Attention
return self

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
if self.ipu_config.embedding_serialization_factor > 1 and self.config.vocab_size > self.actual_vocab_size:
# Ignore the padding logits. Use masking because in-place modification on a slice is not supported yet.
padding_mask = torch.cat(
(
torch.ones(self.actual_vocab_size),
torch.zeros(self.config.vocab_size - self.actual_vocab_size),
)
).to(dtype=lm_logits.dtype, device=lm_logits.device)
lm_logits = lm_logits * padding_mask + (1 - padding_mask) * -10000.0
# TODO: Use the following line instead to ignore the padding logits
# lm_logits[:, :, self.actual_vocab_size:] = -10000
loss = None
if labels is not None:
# Shift so that tokens < n predict n. Use roll() + ignore_index instead of slicing for better efficiency on IPUs.
labels = torch.roll(labels, -1, 1)
# By default the ignore_index of CrossEntropyLoss is -100
labels[:, -1] = -100
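            # Example: labels [t1, t2, t3, t4] become [t2, t3, t4, -100], so position i is supervised by token i+1 and the last position is ignored.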
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
if self.ipu_config.embedding_serialization_factor > 1 and self.config.vocab_size > self.actual_vocab_size:
lm_logits = lm_logits[:, :, : self.actual_vocab_size]
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
if self.training:
# Only returning the loss to make the communication between the host and the device faster.
return (loss,)
else:
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits if not self.training else None,
past_key_values=transformer_outputs.past_key_values if not self.training else None,
hidden_states=transformer_outputs.hidden_states if not self.training else None,
attentions=transformer_outputs.attentions if not self.training else None,
cross_attentions=transformer_outputs.cross_attentions if not self.training else None,
)


@register(GPT2ForSequenceClassification)
class PipelinedGPT2ForSequenceClassification(GPT2ForSequenceClassification, GPT2PipelineMixin):
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Head --> IPU {last_ipu}")
self.score = poptorch.BeginBlock(self.score, "Score", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
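        # GPT2ForSequenceClassification pools the logits at the last non-padding token of each sequence
        # (when batching, `config.pad_token_id` must be set for this to work).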
outputs = super().forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# By default use_cache=True and the model would return past_key_values, which could be very large and cause OOM.
# To prevent this we only return loss and logits during training and evaluation (i.e. when there are labels).
if not return_dict:
loss, logits = outputs[0], outputs[1]
return (loss, logits) if labels is not None else outputs
return SequenceClassifierOutputWithPast(
loss=outputs.loss,
logits=outputs.logits,
past_key_values=outputs.past_key_values if labels is None else None,
hidden_states=outputs.hidden_states if labels is None else None,
attentions=outputs.attentions if labels is None else None,
)


@register(GPT2ForTokenClassification)
class PipelinedGPT2ForTokenClassification(GPT2ForTokenClassification, GPT2PipelineMixin):
def parallelize(self):
super().parallelize()
last_ipu = self.ipu_config._ipus_per_replica - 1
logger.info(f"Head --> IPU {last_ipu}")
self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu)
logger.info("-----------------------------------------------------------")
return self