Excerpt: the `parallelize` method of the IPU-pipelined Whisper model, taken from optimum/graphcore/models/whisper/modeling_whisper.py.


    def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=False, **kwargs):
        """Transform the model for pipelined execution across IPUs.

        Swaps in IPU-specific layer/attention/LM-head classes, then assigns every
        submodule to a pipeline stage via ``poptorch.BeginBlock``. When running for
        generation with a separate (non-conditional) encoder, the IPU config is split
        into independent encoder and decoder configs so the two can run as separate
        Poplar executors.

        Args:
            for_generation: If True, configure for autoregressive text generation.
            use_cache: If True, enable the KV cache (only effective together with
                ``for_generation``); also pulls generation-config values into ``kwargs``.
            use_cross_cache: If True, enable caching of cross-attention KVs (only
                effective together with ``for_generation``).
            **kwargs: Optional feature flags read here: ``use_cond_encoder``,
                ``use_encoder_output_buffer``, ``use_group_quantized_linears``,
                ``on_device_generation_steps``; the remainder is forwarded to the
                class-change helpers.

        Returns:
            self, to allow call chaining.

        Raises:
            ValueError: If both ``use_cond_encoder`` and ``use_encoder_output_buffer``
                are requested, or if ``serialized_projection_splits_per_ipu`` places
                projection splits on more than one IPU.
        """
        super().parallelize()

        if use_cache:
            kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)

        # The two encoder-output strategies are mutually exclusive: either the encoder
        # is fused into the decoder executor (cond encoder) or its output is passed
        # through a buffer between separate executors.
        self._use_cond_encoder = kwargs.get("use_cond_encoder", False)
        self._use_encoder_output_buffer = kwargs.get("use_encoder_output_buffer", False)
        if self._use_cond_encoder and self._use_encoder_output_buffer:
            raise ValueError(
                "`use_cond_encoder=True` is incompatible with `use_encoder_output_buffer=True`, only set one to True."
            )
        self._use_group_quantized_linears = kwargs.get("use_group_quantized_linears", False)

        # Replace stock transformers classes with IPU-specific implementations.
        # Caching is only enabled when actually generating.
        self.change_encoder_layer_class(restore=False)
        self.change_decoder_class(restore=False)
        self.change_decoder_positional_embedding(restore=False)
        self.change_attention_class(
            restore=False,
            use_cache=use_cache and for_generation,
            use_cross_cache=use_cross_cache and for_generation,
            **kwargs,
        )
        self.change_lm_head(restore=False, use_cache=use_cache or not for_generation)
        # ``restore=True`` here means "keep the stock class", so the encoder is only
        # swapped when the conditional-encoder mode is on; same pattern for quantization.
        self.change_encoder_class(restore=not self._use_cond_encoder, **kwargs)
        self.quantize_linear_layers(restore=not self._use_group_quantized_linears, num_groups=16)
        self.set_on_device_generation_steps(kwargs.get("on_device_generation_steps", 0))

        logger.info("---------- Device Allocation -----------")
        logger.info("conv1, conv2, embed_positions  --> IPU 0")
        # The feature-extraction front end is pinned to the first IPU.
        self.model.encoder.conv1 = poptorch.BeginBlock(self.model.encoder.conv1, "Conv1", ipu_id=0)
        self.model.encoder.conv2 = poptorch.BeginBlock(self.model.encoder.conv2, "Conv2", ipu_id=0)
        self.model.encoder.embed_positions = poptorch.BeginBlock(
            self.model.encoder.embed_positions, "Embed Positions", ipu_id=0
        )

        num_encoder_layers = len(self.model.encoder.layers)
        num_decoder_layers = len(self.model.decoder.layers)

        if for_generation and not self._use_cond_encoder:
            # If running for text generation (and the encoder and decoder are run as separate Poplar executors)
            # we split the IPU config into two configs.
            ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
            self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
            encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
            decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
        else:
            # Single executor: one layer->IPU mapping covering encoder then decoder.
            number_of_layers = num_encoder_layers + num_decoder_layers
            layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
            encoder_layer_ipu = layer_ipu[:num_encoder_layers]
            decoder_layer_ipu = layer_ipu[num_encoder_layers:]

        for index, (layer, ipu) in enumerate(zip(self.model.encoder.layers, encoder_layer_ipu)):
            # Checkpoint every layer's output except the last to trade memory for recompute.
            # NOTE(review): the guard compares against `self.config.num_hidden_layers`
            # while the loop length is `num_encoder_layers` — confirm these always agree.
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        # we need to deal with the model.encoder.layer norm
        # NOTE(review): `ipu` is reused from the loop above, i.e. the final norm lands on
        # the last encoder layer's IPU; assumes the encoder has at least one layer.
        self.model.encoder.layer_norm = poptorch.BeginBlock(
            self.model.encoder.layer_norm, "Encoder Layer Norm", ipu_id=ipu
        )
        logger.info(f"Encoder LN --> IPU {ipu}")

        # By default the decoder embedding shares the first decoder layer's IPU, unless
        # serialized projection splits pin it (and the tied lm head) to a specific IPU.
        decoder_embedding_ipu = decoder_layer_ipu[0]
        if (serialized_projection_splits_per_ipu := self.ipu_config._serialized_projection_splits_per_ipu) is not None:
            serialized_projection_ipus = [i for i, x in enumerate(serialized_projection_splits_per_ipu) if x]
            if len(serialized_projection_ipus) > 1:
                # This is because we are using SerializedLinear. All splits of a SerializedLinear layer must be on the
                # same IPU. We are using SerializedLinear instead of SplitLinear because we must tie the weights, which
                # cannot be done when using SplitLinear.
                raise ValueError(
                    "`serialized_projection_splits_per_ipu` must only have 1 non-zero element for Whisper."
                )
            decoder_embedding_ipu = serialized_projection_ipus[0]
        self.model.decoder.embed_tokens = poptorch.BeginBlock(
            self.model.decoder.embed_tokens, "Decoder Embedding", ipu_id=decoder_embedding_ipu
        )
        logger.info(f"Decoder Embedding  --> IPU {decoder_embedding_ipu}")

        # Only open a new pipeline stage when the target IPU changes; consecutive layers
        # on the same IPU stay in the previous stage. The first decoder layer therefore
        # starts no stage of its own and follows the embedding block.
        prev_ipu = decoder_layer_ipu[0]
        for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
            # NOTE(review): same `num_hidden_layers` guard as the encoder loop — for the
            # decoder this compares against the encoder-sized config value; verify intended.
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            if ipu != prev_ipu:
                self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
                prev_ipu = ipu
            logger.info(f"Decoder {index:<2} --> IPU {ipu}")

        # `ipu` here is the last decoder layer's IPU (loop variable reuse, as above).
        self.model.decoder.layer_norm = poptorch.BeginBlock(
            self.model.decoder.layer_norm, "Decoder Layer Norm", ipu_id=ipu
        )

        logger.info(f"Head       --> IPU {decoder_embedding_ipu}")
        logger.info("---------------------------------------")
        # Output projection is placed with the decoder embedding — presumably because the
        # weights are tied (see SerializedLinear comment above).
        self.proj_out = poptorch.BeginBlock(self.proj_out, "Output Projection", ipu_id=decoder_embedding_ipu)
        return self