def convert_scan_to_unroll()

in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]) -> Dict:
        r"""
        Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be
        used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does
        not convert the `params` in place.

        With scan enabled, each parameter of the encoder (decoder) layers is stored as a single leaf whose leading
        axis stacks the values for all N layers, under a path containing the marker `FlaxEncoderScanLayers`
        (`FlaxDecoderScanLayers`). Schematically, for the query projection (`q_proj`):
            (..., 'layers', 'FlaxEncoderScanLayers', 'self_attn', 'q_proj', ...)

        This method slices each such stacked leaf into N standalone leaves, replacing the marker in the path with the
        layer index:
            (..., 'layers', '0', 'self_attn', 'q_proj', ...) (..., 'layers', '1', 'self_attn', 'q_proj', ...) ...
            (..., 'layers', 'N-1', 'self_attn', 'q_proj', ...)

        The number of layers N is read from `self.config.encoder_layers` for encoder leaves and from
        `self.config.decoder_layers` for decoder leaves.

        When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
        _do_init=False, it will have to be called explicitly (see example below).

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.

        Returns:
            `Dict`: a new (unfrozen) nested dict of parameters in unrolled format.

        Raises:
            ValueError: if the leading dimension of a scanned leaf does not match the configured number of layers.

        Examples:

        ```python
        >>> from distil_whisper import FlaxWhisperForConditionalGeneration

        >>> # Download model and configuration from huggingface.co
        >>> model, params = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", _do_init=False)
        >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
        >>> # we'll first convert to scan format and then back to unrolled
        >>> model.enable_scan()
        >>> params = model.convert_unroll_to_scan(params)
        >>> # now convert back to unrolled
        >>> model.disable_scan()
        >>> params = model.convert_scan_to_unroll(params)
        ```"""

        if isinstance(params, FrozenDict):
            params = unfreeze(params)

        # Work on a flat {"a/b/c": leaf} view so scanned leaves can be located by substring match on the path.
        params = flatten_dict(params, sep="/")

        # (path marker, config attribute holding the number of stacked layers). The encoder marker is tested first,
        # matching the precedence of the original if/elif chain.
        scan_markers = (
            ("FlaxEncoderScanLayers", "encoder_layers"),
            ("FlaxDecoderScanLayers", "decoder_layers"),
        )

        # Snapshot the keys: the loop pops scanned keys and inserts unrolled ones while iterating.
        for key in list(params.keys()):
            for marker, num_layers_attr in scan_markers:
                if marker not in key:
                    continue

                num_layers = getattr(self.config, num_layers_attr)
                scan_layer = params.pop(key)

                # Guard against silently dropping layers (too many rows) or an opaque IndexError (too few).
                if len(scan_layer) != num_layers:
                    raise ValueError(
                        f"Expected {num_layers} stacked layers (config.{num_layers_attr}) for scanned parameter "
                        f"'{key}', but its leading dimension is {len(scan_layer)}."
                    )

                # Index each layer directly instead of repeatedly re-slicing the remainder (`scan_layer[1:]`), which
                # for JAX arrays materialises a fresh copy of the remaining layers on every iteration.
                for i in range(num_layers):
                    params[key.replace(marker, str(i))] = scan_layer[i]
                break

        return unflatten_dict(params, sep="/")