def convert_unroll_to_scan()

in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]):
        r"""
        Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used
        to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not
        convert the `params` in place.

        To illustrate the workings of this method, take the query projection params of an encoder with N layers. In
        the unrolled format there is one entry per layer:
            (..., 'encoder', 'layers', '0', 'self_attn', 'q_proj') (..., 'encoder', 'layers', '1', 'self_attn',
            'q_proj') ... (..., 'encoder', 'layers', 'N-1', 'self_attn', 'q_proj')
        This method takes each of the `q_proj` matrices for layers (0, ..., N-1) and stacks them along a new leading
        axis into a single 'super' matrix, giving a *single* block of weights for all N layers:
            (..., 'encoder', 'layers', 'FlaxEncoderScanLayers', 'self_attn', 'q_proj')
        Parameters whose path contains `decoder` are stacked under `FlaxDecoderScanLayers` instead, using
        `config.decoder_layers` as the number of layers; all other layer parameters use `config.encoder_layers`.

        When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
        _do_init=False, it will have to be called explicitly (see example below).

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.

        Returns:
            A nested (unfrozen) dict of parameters in which every per-layer entry under `layers/<i>` has been
            replaced by a single stacked entry under `layers/Flax{Encoder,Decoder}ScanLayers`.

        Raises:
            `KeyError`: if a parameter found for layer 0 is missing for one of the remaining layers.

        Examples:

        ```python
        >>> from distil_whisper import FlaxWhisperForConditionalGeneration

        >>> # Download model and configuration from huggingface.co
        >>> model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
        ...     "openai/whisper-tiny.en", _do_init=False
        ... )
        >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
        >>> # we'll first convert to scan format and then back to unrolled
        >>> model.enable_scan()
        >>> params = model.convert_unroll_to_scan(params)
        >>> # now convert back to unrolled
        >>> model.disable_scan()
        >>> params = model.convert_scan_to_unroll(params)
        ```"""
        if isinstance(params, FrozenDict):
            params = unfreeze(params)

        params = flatten_dict(params, sep="/")

        # Layer 0 acts as the template: every param found under `layers/0` is expected to
        # exist under `layers/1`, ..., `layers/N-1` as well.
        layer_zero = "layers/0"

        # Snapshot the keys up front, since entries are popped from / added to `params` below.
        for k in list(params.keys()):
            if layer_zero not in k:
                continue

            # Check for "decoder" first: decoder paths may also contain the substring
            # "encoder" (e.g. cross-attention modules), whereas the reverse does not hold.
            if "decoder" in k:
                block_prefix = "Decoder"
                num_hidden_layers = self.config.decoder_layers
            else:
                block_prefix = "Encoder"
                num_hidden_layers = self.config.encoder_layers

            # Split around the layer index so that ONLY this segment is rewritten. A blanket
            # `k.replace("0", ...)` would also rewrite any other "0" in the path (e.g. a
            # sub-module called `fc0`), giving a wrong scan key and failed lookups.
            head, _, tail = k.partition(layer_zero)

            # Squash the keys for the N unrolled layers into one single key:
            # (layers/0, ..., layers/N-1) -> layers/Flax{Encoder,Decoder}ScanLayers
            scan_key = f"{head}layers/Flax{block_prefix}ScanLayers{tail}"

            # Pop the unrolled params for layers (0, ..., N-1) while collecting them, so the
            # per-layer entries do not remain in the dict alongside the stacked block.
            stacked_params = [params.pop(f"{head}layers/{i}{tail}") for i in range(num_hidden_layers)]

            params[scan_key] = jnp.stack(stacked_params)

        # Finally, unflatten the dict to restore the nested pytree structure
        params = unflatten_dict(params, sep="/")
        return params