in optimum/graphcore/models/t5/modeling_t5.py [0:0]
def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=False, **kwargs):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- (If enabled) Replaces the LM head with a SerializedLinear layer and re-ties the weights if they are shared with the embedding
- (If enabled) Adds recomputation checkpoints
Recommended usage:
```
model = PipelinedT5ForConditionalGeneration(config).parallelize().half()
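# Sketch for text generation with KV caching (assumes the default kwargs are suitable):
model = PipelinedT5ForConditionalGeneration(config).parallelize(for_generation=True, use_cache=True).half()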
```
"""
PipelineMixin.parallelize(self)
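# When KV caching is requested, fill in any missing generation-related kwargs from the generation config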
if use_cache:
kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
if self.ipu_config.embedding_serialization_factor > 1:
self.lm_head = SerializedLinear.from_model(self.lm_head, self.ipu_config.embedding_serialization_factor)
# TODO: is this check needed?
if self.config.tie_word_embeddings:
self.tie_weights()
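# Use an indexed-input LM head only when generating without a KV cache; otherwise restore the standard linear head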
self.change_lm_head_to_indexed_input_linear(restore=not (for_generation and not use_cache))
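# Compute both encoder and decoder token embeddings through the shared embedding module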
self.encoder_and_decoder_embeddings_computation(True)
self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0)
# Use a custom T5Stack implementation because sharing the position bias causes an OOM error
self.encoder.__class__ = CustomT5Stack
self.decoder.__class__ = CustomT5Stack
# Optimisations for generation
self.change_attention_class(
restore=False,
use_cache=use_cache and for_generation,
use_cross_cache=use_cross_cache and for_generation,
**kwargs,
)
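# Generation-only options taken from kwargs: whether to buffer the encoder output between the
# encoder and decoder executors, and how many generation steps to run on device per host call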
self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))
# Upcast the input embeddings so that the residuals remain in FP32. This
# cast is reversed where necessary by the T5LayerNorm layers in:
# - the first layer of T5LayerSelfAttention
# - the first layer of T5LayerFF
# - final_layer_norm
# Conveniently, these are all the places where the cast needs to be
# reversed, so we only need to upcast immediately before the residual
# adds in T5LayerSelfAttention and T5LayerFF. This is handled in the
# for loop below.
self.encoder.embed_tokens = UpCastWrapper(self.encoder.embed_tokens)
# Use a custom T5Block implementation that removes dynamic `if` blocks that cannot be statically traced
for block in self.encoder.block:
block.__class__ = CustomT5Block
# Dropout happens immediately before the residual add. Wrapping it in an
# upcast keeps the residual path in T5LayerSelfAttention and T5LayerFF
# in FP32.
block.layer[0].dropout = UpCastWrapper(block.layer[0].dropout)
# Scale down the weights for the T5LayerFF down-projection and
# then scale its output back up again after it is cast to FP32
scale = 8.0
with torch.no_grad():
block.layer[1].DenseReluDense.wo.weight /= scale
block.layer[1].dropout = UpCastWrapper(block.layer[1].dropout, scale)
# Prevent overflow in NewGELUActivation
if self.config.dense_act_fn == "gelu_new":
# TODO: Work around a bug with torch.nn.GELU(approximate="tanh"). Replace
# this with block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")
# when the bug is fixed
block.layer[1].DenseReluDense.act = CustomGELU()
for block in self.decoder.block:
block.__class__ = CustomT5Block
# Work around a bug with torch.nn.GELU(approximate="tanh")
# TODO: Remove this when the bug is fixed
if self.config.dense_act_fn == "gelu_new":
block.layer[2].DenseReluDense.act = CustomGELU()
num_encoder_layers = len(self.encoder.block)
num_decoder_layers = len(self.decoder.block)
if for_generation:
# If running for text generation we split the IPU config into two configs
# because we run the encoder and decoder as separate Poplar executors.
ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
else:
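# A single executor runs both stacks, so one pipeline allocation is split between encoder and decoder layers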
number_of_layers = num_encoder_layers + num_decoder_layers
layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
encoder_layer_ipu = layer_ipu[:num_encoder_layers]
decoder_layer_ipu = layer_ipu[num_encoder_layers:]
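# Assign each encoder block to its pipeline stage; optionally add a recomputation checkpoint after every layer except the last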
for index, (layer, ipu) in enumerate(zip(self.encoder.block, encoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")
self.encoder.final_layer_norm = poptorch.BeginBlock(
self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu
)
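# Assign each decoder block to its pipeline stage, mirroring the encoder placement above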
for index, (layer, ipu) in enumerate(zip(self.decoder.block, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != num_decoder_layers - 1:
self._hooks.append(recomputation_checkpoint(layer))
self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
logger.info(f"Decoder {index:<2} --> IPU {ipu}")
self.decoder.final_layer_norm = poptorch.BeginBlock(
self.decoder.final_layer_norm, "Decoder Stack Final LayerNorm", ipu_id=ipu
)
logger.info("LM Head Output --> IPU 0")
self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self