def parallelize()

in optimum/graphcore/models/bart/modeling_bart.py


    def parallelize(self, for_generation=False, use_cache=False, **kwargs):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - (If enabled) Replaces the LM head with a SerializedLinear layer
        - Adds recomputation checkpoints

        Recommended usage:
        ```
        model = PipelinedBartForConditionalGeneration(config).parallelize().half()
        ```
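
        For text generation, parallelize with `for_generation=True` (optionally
        with `use_cache=True` to enable KV caching), for example:
        ```
        model = PipelinedBartForConditionalGeneration(config).parallelize(
            for_generation=True, use_cache=True
        ).half()
        ```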
        """
        super().parallelize()

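        # With KV caching enabled, populate kwargs with generation-specific options
        # from the generation config before configuring the pipeline below.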
        if use_cache:
            kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)

        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding  --> IPU 0")

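        # Serialize the LM head projection into `embedding_serialization_factor`
        # slices (reducing peak memory), then re-tie its weights with the shared
        # embedding.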
        if self.ipu_config.embedding_serialization_factor > 1:
            self.lm_head = SerializedLinear.from_model(self.lm_head, self.ipu_config.embedding_serialization_factor)
            self.tie_weights()

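        # Swap in IPU-specific implementations: share a single embedding between the
        # encoder and decoder, and replace the encoder/decoder, attention and decoder
        # positional-embedding classes with their pipeline-friendly variants.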
        self.model.__class__ = _BartModelWithSharedEmbedding
        self.model.encoder_and_decoder_embeddings_computation(use_shared_embedding=True)
        self.model.change_bart_encoder_and_decoder_classes(restore=False)
        self.model.change_bart_attention_class(restore=False, use_cache=use_cache and for_generation, **kwargs)
        self.model.change_decoder_positional_embedding(restore=False)
        self.change_lm_head_to_indexed_input_linear(restore=not (for_generation and not use_cache))
        self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
        self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))
        self.model.quantize_linear_layers(restore=not kwargs.get("use_group_quantized_linears", False), num_groups=16)

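        # First pipeline stage: the shared embedding and the encoder's positional
        # and layernorm embeddings are placed on IPU 0 as the "Embedding" block.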
        self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0)
        self.model.encoder.embed_positions = poptorch.BeginBlock(
            self.model.encoder.embed_positions, "Embedding", ipu_id=0
        )
        self.model.encoder.layernorm_embedding = poptorch.BeginBlock(
            self.model.encoder.layernorm_embedding, "Embedding", ipu_id=0
        )

        num_encoder_layers = len(self.model.encoder.layers)
        num_decoder_layers = len(self.model.decoder.layers)

        if for_generation:
            # When running text generation, the IPU config is split in two because
            # the encoder and decoder run as separate Poplar executors.
            ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
            self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
            encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
            decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
        else:
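            # Single Poplar executor: compute one layer-to-IPU mapping over all
            # layers and split it at the encoder/decoder boundary.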
            number_of_layers = num_encoder_layers + num_decoder_layers
            layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
            encoder_layer_ipu = layer_ipu[:num_encoder_layers]
            decoder_layer_ipu = layer_ipu[num_encoder_layers:]

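        # Assign each encoder layer to its IPU; if recompute_checkpoint_every_layer
        # is set, add a recomputation checkpoint after every layer except the last.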
        for index, (layer, ipu) in enumerate(zip(self.model.encoder.layers, encoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

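        # The decoder's positional and layernorm embeddings join the same
        # "Embedding" block on IPU 0.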
        self.model.decoder.embed_positions = poptorch.BeginBlock(
            self.model.decoder.embed_positions, "Embedding", ipu_id=0
        )
        self.model.decoder.layernorm_embedding = poptorch.BeginBlock(
            self.model.decoder.layernorm_embedding, "Embedding", ipu_id=0
        )

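        # Assign each decoder layer to its IPU, with the same optional per-layer
        # recomputation checkpointing as the encoder.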
        for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
            logger.info(f"Decoder {index:<2} --> IPU {ipu}")

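        # The LM head is placed back on IPU 0, alongside the shared embedding
        # whose weights it is tied to.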
        logger.info("LM Head Output --> IPU 0")
        self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0)
        logger.info("-----------------------------------------------------------")
        return self