in optimum/graphcore/models/gpt2/modeling_gpt2.py
def parallelize(self, for_generation=False):
    """
    Transform the model to run in an IPU pipeline.

    - Adds pipeline stages to the model
    - (If enabled) Serializes the LM head with a SerializedLinear layer
    - Adds recomputation checkpoints

    Recommended usage:
    ```
    model = PipelinedGPT2LMHeadModel(config).parallelize().half()
    ```
    """
    PipelineMixin.parallelize(self)

    # Use optimized attention
    for layer in self.transformer.h:
        layer.attn.__class__ = OptimizedGPT2Attention

    if self.ipu_config.embedding_serialization_factor > 1:
        # Resize token embedding using padding if vocab_size is not a multiple of embedding_serialization_factor.
        self.actual_vocab_size = self.config.vocab_size
        new_vocab_size = (
            math.ceil(self.config.vocab_size / self.ipu_config.embedding_serialization_factor)
            * self.ipu_config.embedding_serialization_factor
        )
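        # e.g. with GPT-2's default vocab_size of 50257 and an illustrative
        # embedding_serialization_factor of 4: ceil(50257 / 4) * 4 = 50260, i.e. 3 padding rows.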
        if new_vocab_size > self.actual_vocab_size:
            # There is a tie_weights operation in resize_token_embeddings so the lm_head's weight is also resized.
            self.resize_token_embeddings(new_vocab_size)
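        # Serialize the LM head matmul into embedding_serialization_factor slices to lower
        # peak memory, then re-tie its weight with the (possibly resized) token embedding.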
        self.lm_head = SerializedLinear.from_model(self.lm_head, self.ipu_config.embedding_serialization_factor)
        self.tie_weights()

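    # For generation, swap the LM head for an indexed-input variant so that only the token
    # positions whose logits are actually needed get projected; otherwise restore the plain head.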
    self.change_lm_head_to_indexed_input_linear(restore=not for_generation)

logger.info("-------------------- Device Allocation --------------------")
logger.info("Token Embedding --> IPU 0")
self.transformer.wte = poptorch.BeginBlock(self.transformer.wte, "Token embedding", ipu_id=0)
logger.info("Position Embedding --> IPU 0")
self.transformer.wpe = poptorch.BeginBlock(self.transformer.wpe, "Position embedding", ipu_id=0)
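    # Tag the final LayerNorm so its operations can be outlined (compiled once and reused);
    # the returned hooks are kept so they can be removed again when the model is deparallelized.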
    hs = outline_attribute(self.transformer.ln_f, "LayerNorm")
    self._hooks.extend(hs)

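    # Map each transformer block to an IPU according to the IPUConfig (e.g. its layers_per_ipu setting).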
    layer_ipu = get_layer_ipu(self.ipu_config, self.transformer.h)
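    # Assign each block to its pipeline stage; optionally mark every block's output as a
    # recomputation checkpoint so intermediate activations are recomputed rather than stored.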
    for index, layer in enumerate(self.transformer.h):
        ipu = layer_ipu[index]
        if self.ipu_config.recompute_checkpoint_every_layer:
            h = recomputation_checkpoint(layer)
            self._hooks.append(h)
        self.transformer.h[index] = poptorch.BeginBlock(layer, f"Layer{index}", ipu_id=ipu)
        logger.info(f"Layer {index:<2} --> IPU {ipu}")

logger.info("Head --> IPU 0")
self.lm_head = poptorch.BeginBlock(self.lm_head, "LM head", ipu_id=0)
logger.info("-----------------------------------------------------------")
return self
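

# --- Illustrative usage sketch (not part of the original file) ----------------------------
# A minimal sketch of how parallelize() is typically combined with poptorch for inference.
# Assumptions: the IPUConfig values below are placeholders, and attaching ipu_config directly
# is shown only for illustration (the library normally wires it up via its own
# from_pretrained / from_transformers helpers).
from transformers import GPT2Config

from optimum.graphcore import IPUConfig

config = GPT2Config()
model = PipelinedGPT2LMHeadModel(config)
model.ipu_config = IPUConfig(layers_per_ipu=[3, 3, 3, 3], embedding_serialization_factor=4)  # placeholder values
model = model.parallelize().half()  # pass for_generation=True when driving text generation

# Wrap the pipelined, half-precision model for execution on IPUs.
opts = poptorch.Options()
inference_model = poptorch.inferenceModel(model.eval(), options=opts)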