in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]):
    r"""
    Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be
    used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does
    not convert the `params` in place.

    In scanned format, each weight of a layer stack is stored as a single stacked array over all N layers, under
    a key that contains the scan identifier (`FlaxEncoderScanLayers` for the encoder stack,
    `FlaxDecoderScanLayers` for the decoder stack). For instance (key prefix shown for illustration only):

        ('...', 'FlaxEncoderScanLayers', 'self_attn', 'q_proj')

    This method slices the stacked array along its leading axis into single, standalone layers, and replaces the
    scan identifier in the key with the layer index on the fly:

        ('...', '0', 'self_attn', 'q_proj') ('...', '1', 'self_attn', 'q_proj') ... ('...', 'N-1', 'self_attn',
        'q_proj')

    The number of slices taken is `config.encoder_layers` for encoder keys and `config.decoder_layers` for
    decoder keys. Keys that contain neither identifier are left untouched.

    When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
    _do_init=False, it will have to be called explicitly (see example below).

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters. A `FrozenDict` is unfrozen before conversion.

    Returns:
        The nested parameter tree rebuilt by `unflatten_dict`, with every scanned entry replaced by one entry per
        layer.

    Examples:

    ```python
    >>> from distil_whisper import FlaxWhisperForConditionalGeneration

    >>> # Download model and configuration from huggingface.co
    >>> model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    ...     "openai/whisper-tiny.en", _do_init=False
    ... )
    >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
    >>> # we'll first convert to scan format and then back to unrolled
    >>> model.enable_scan()
    >>> params = model.convert_unroll_to_scan(params)
    >>> # now convert back to unrolled
    >>> model.disable_scan()
    >>> params = model.convert_scan_to_unroll(params)
    ```"""
    # (identifier embedded in the flattened key, config attribute holding that stack's layer count).
    # Order matters: the encoder identifier is tested first, and a key is unrolled for at most one stack.
    scan_stacks = (
        ("FlaxEncoderScanLayers", "encoder_layers"),
        ("FlaxDecoderScanLayers", "decoder_layers"),
    )

    if isinstance(params, FrozenDict):
        params = unfreeze(params)

    params = flatten_dict(params, sep="/")

    # Snapshot the keys up front: entries are popped from / inserted into `params` while we iterate.
    for key in list(params.keys()):
        for scan_name, layers_attr in scan_stacks:
            if scan_name not in key:
                continue

            # Remove the stacked entry from the flat dict before writing the per-layer entries.
            stacked = params.pop(key)

            # The layer count is only read from the config once a matching key has been found.
            for layer_idx in range(getattr(self.config, layers_attr)):
                # .../<scan_name>/... -> .../<layer_idx>/...
                unrolled_key = key.replace(scan_name, str(layer_idx))
                # Peel the first slice off the stack and keep only the remainder for the next iteration.
                params[unrolled_key], stacked = stacked[0], stacked[1:]
            break

    return unflatten_dict(params, sep="/")