def parallelize()

in optimum/graphcore/models/mt5/modeling_mt5.py


    def parallelize(self, for_generation=False):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - (If enabled) Replaces the shared embedding with a SerializedEmbedding
        - Adds recomputation checkpoints

        Recommended usage:
        ```
        model = PipelinedMT5ForConditionalGeneration(config).parallelize().half()
        ```
        """
        PipelineMixin.parallelize(self)

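        # Resolve the serialization factors for the embedding table and the LM head projection,
        # falling back to the sum of the per-IPU splits when no explicit factor is configured.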
        serialized_projection_splits_per_ipu = self.ipu_config._serialized_projection_splits_per_ipu
        projection_serialization_factor = (
            self.ipu_config._projection_serialization_factor
            if self.ipu_config._projection_serialization_factor
            else sum(serialized_projection_splits_per_ipu)
        )
        serialized_embedding_splits_per_ipu = self.ipu_config._serialized_embedding_splits_per_ipu
        embedding_serialization_factor = (
            self.ipu_config._embedding_serialization_factor
            if self.ipu_config._embedding_serialization_factor
            else sum(serialized_embedding_splits_per_ipu)
        )

        # Input and output embeddings cannot be sharded when using tied
        # weights. Using `SerializedLinear` is still allowed since its
        # weights are not sharded.
        if self.config.tie_word_embeddings and (
            embedding_serialization_factor > 1 or serialized_projection_splits_per_ipu is not None
        ):
            serialized_projection_splits_per_ipu_mode_str = self.ipu_config._get_managed_attr_mode_name(
                "serialized_projection_splits_per_ipu"
            )
            serialized_embedding_splits_per_ipu_mode_str = self.ipu_config._get_managed_attr_mode_name(
                "serialized_embedding_splits_per_ipu"
            )
            embedding_serialization_factor_mode_str = self.ipu_config._get_managed_attr_mode_name(
                "embedding_serialization_factor"
            )
            raise ValueError(
                "Cannot shard input and output embedding layers when using tied weights."
                f" {serialized_projection_splits_per_ipu_mode_str}={serialized_projection_splits_per_ipu}"
                f" {serialized_embedding_splits_per_ipu_mode_str}={serialized_embedding_splits_per_ipu}"
                " should not be provided when using tied input and output embeddings as it is"
                " redundant to split layers that can only reside on 1 IPU."
                f" {embedding_serialization_factor_mode_str}={embedding_serialization_factor}"
                " should also be set to 1 as creating a `SerializedEmbedding` will split the"
                " embedding table into sub embedding tables."
            )

        logger.info("-------------------- Device Allocation --------------------")

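        # Shard the shared embedding table into sub-embeddings and reuse it in both the
        # encoder and the decoder.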
        if embedding_serialization_factor > 1:
            self.shared = SerializedEmbedding.from_model(self.shared, embedding_serialization_factor)
            self.encoder.embed_tokens = self.shared
            self.decoder.embed_tokens = self.shared

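        # Serialize the LM head projection: SerializedLinear keeps a single (tie-able) weight
        # tensor, while SplitProjection shards the weights so they can be spread across IPUs.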
        if projection_serialization_factor > 1:
            if serialized_projection_splits_per_ipu is None:
                self.lm_head = SerializedLinear.from_model(self.lm_head, projection_serialization_factor)
                if self.config.tie_word_embeddings:
                    self.tie_weights()
            else:
                self.lm_head = SplitProjection.from_model(
                    self.lm_head, serialization_factor=projection_serialization_factor
                )

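        # Share the embedding computation between the encoder and the decoder
        # (this wraps `self.shared` in a SharedEmbedding).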
        self.encoder_and_decoder_embeddings_computation(True)

        # Parallelize the embedding layer
        if embedding_serialization_factor > 1 and serialized_embedding_splits_per_ipu is not None:
            # Sharing encoder and decoder computation wraps the
            # SerializedEmbedding using SharedEmbedding
            logger.info("Embedding Placement: ")
            self.shared.shared = self.shared.shared.parallelize(serialized_embedding_splits_per_ipu)
        else:
            logger.info("Embedding  --> IPU 0")
            self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0)

        # Use a custom MT5Stack implementation because sharing the position bias causes an OOM error
        self.encoder.__class__ = CustomMT5Stack
        self.decoder.__class__ = CustomMT5Stack

        # Upcast input embeddings so that the residuals remain in FP32. This
        # cast is reversed where necessary by the MT5LayerNorm layers in:
        # - first layer of MT5LayerSelfAttention
        # - first layer of MT5LayerFF
        # - final_layer_norm
        # These are, conveniently, all the places where the cast needs to be
        # reversed, so we just need to upcast immediately before the residual
        # adds in MT5LayerSelfAttention and MT5LayerFF. This is handled in the
        # for loop below.
        self.encoder.embed_tokens = UpCastWrapper(self.encoder.embed_tokens)

        # Use a custom MT5Block implementation that removes dynamic if blocks that can't be statically traced
        for block in self.encoder.block:
            block.__class__ = CustomMT5Block
            # Dropout happens immediately before the residual add. Inserting a
            # cast in MT5LayerSelfAttention and MT5LayerFF keeps the residual
            # structure in FP32
            block.layer[0].dropout = UpCastWrapper(block.layer[0].dropout)
            # Scale down the weights for the MT5LayerFF down-projection and
            # then scale its output back up again after it is cast to FP32
            scale = 8.0
            with torch.no_grad():
                block.layer[1].DenseReluDense.wo.weight /= scale
            block.layer[1].dropout = UpCastWrapper(block.layer[1].dropout, scale)
            # Prevent overflow in NewGELUActivation
            if self.config.dense_act_fn == "gelu_new":
                # TODO: Work around a bug with torch.nn.GELU(approximate="tanh"). Replace
                # this with block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")
                # when the bug is fixed
                block.layer[1].DenseReluDense.act = CustomGELU()

        for block in self.decoder.block:
            block.__class__ = CustomMT5Block
            # Work around a bug with torch.nn.GELU(approximate="tanh")
            # TODO: Remove this when the bug is fixed
            if self.config.dense_act_fn == "gelu_new":
                block.layer[2].DenseReluDense.act = CustomGELU()

        num_encoder_layers = len(self.encoder.block)
        num_decoder_layers = len(self.decoder.block)

        if for_generation:
            # If running for text generation, we split the IPU config into two configs
            # because we run the encoder and decoder as separate Poplar executors.
            ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
            self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
            encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
            decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
        else:
            number_of_layers = num_encoder_layers + num_decoder_layers
            layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
            encoder_layer_ipu = layer_ipu[:num_encoder_layers]
            decoder_layer_ipu = layer_ipu[num_encoder_layers:]

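        # Assign each encoder block to its pipeline stage and (optionally) add a
        # recomputation checkpoint after every layer except the last.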
        for index, (layer, ipu) in enumerate(zip(self.encoder.block, encoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        self.encoder.final_layer_norm = poptorch.BeginBlock(
            self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu
        )

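        # Same placement scheme for the decoder blocks.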
        for index, (layer, ipu) in enumerate(zip(self.decoder.block, decoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
            logger.info(f"Decoder {index:<2} --> IPU {ipu}")

        self.decoder.final_layer_norm = poptorch.BeginBlock(
            self.decoder.final_layer_norm, "Decoder Stack Final LayerNorm", ipu_id=ipu
        )

        # Parallelize the lm head
        if self.config.tie_word_embeddings:
            # Place LM head on IPU 0
            ipu_id = 0
            logger.info(f"LM Head Output --> IPU {ipu_id}")
            self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=ipu_id)
        else:
            # Place LM head on the last IPU if serialized_projection_splits_per_ipu is not provided
            # For generation: override serialized_projection_splits_per_ipu
            ipu_id = self.ipu_config._ipus_per_replica - 1
            if for_generation:
                serialized_projection_splits_per_ipu = self.decoder_ipu_config._serialized_projection_splits_per_ipu
                ipu_id = self.decoder_ipu_config._ipus_per_replica - 1

            # Parallelize the `SplitProjection` layer if a configuration is provided
            if self.lm_head.__class__ == SplitProjection:
                logger.info("LM Head Placement: ")
                self.lm_head = self.lm_head.parallelize(serialized_projection_splits_per_ipu)
            else:
                # Place SerializedLinear and nn.Linear forms of the lm head on the last IPU
                logger.info(f"LM Head Output --> IPU {ipu_id}")
                self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=ipu_id)

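        # Use the indexed-input variant of the LM head for generation; restore the plain head otherwise.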
        self.change_lm_head_to_indexed_input_linear(restore=not for_generation)

        logger.info("-----------------------------------------------------------")
        return self
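
For text generation, the method is called with `for_generation=True`, which splits the IPU config into separate encoder and decoder configs. A minimal sketch, assuming `config` is an `MT5Config` and the model's `ipu_config` has already been attached by the usual Optimum Graphcore setup:

```
model = PipelinedMT5ForConditionalGeneration(config)
model = model.parallelize(for_generation=True).half()
# model.encoder_ipu_config and model.decoder_ipu_config now hold the split configs
```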