in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]):
    r"""
    Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used
    to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not
    convert the `params` in place.

    To illustrate the workings of this method, consider the encoder of the Whisper model. The unrolled structure for
    the query projection params of an N-layer encoder is as follows:

        (..., 'encoder', 'layers', '0', 'self_attn', 'q_proj') ... (..., 'encoder', 'layers', 'N-1', 'self_attn',
        'q_proj')

    This method takes each of the `q_proj` leaves for layers (0, ..., N-1) and stacks them along a new leading axis
    into a single 'super' array, giving a *single* block of weights for all N layers compatible with the scanned
    model:

        (..., 'encoder', 'layers', 'FlaxEncoderScanLayers', 'self_attn', 'q_proj')

    Decoder layers are handled the same way, using the segment name 'FlaxDecoderScanLayers'. The number of layers
    stacked is taken from `self.config.encoder_layers` / `self.config.decoder_layers`.

    When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
    _do_init=False, it will have to be called explicitly (see example below).

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.

    Returns:
        `Dict`: a new nested (unfrozen) parameter dict in which every unrolled `layers/<i>` sub-tree has been
        replaced by a single stacked `layers/Flax{Encoder,Decoder}ScanLayers` sub-tree.

    Raises:
        `KeyError`: if an unrolled layer index in `range(num_layers)` is missing from `params`.

    Examples:

    ```python
    >>> from distil_whisper import FlaxWhisperForConditionalGeneration

    >>> # Download model and configuration from huggingface.co
    >>> model, params = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", _do_init=False)
    >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
    >>> # we'll first convert to scan format and then back to unrolled
    >>> model.enable_scan()
    >>> params = model.convert_unroll_to_scan(params)
    >>> # now convert back to unrolled
    >>> model.disable_scan()
    >>> params = model.convert_scan_to_unroll(params)
    ```"""
    if isinstance(params, FrozenDict):
        params = unfreeze(params)

    params = flatten_dict(params, sep="/")
    # Snapshot the keys up front: the entries for layers 1..N-1 are popped from `params` while we iterate.
    keys = list(params.keys())
    for k in keys:
        parts = k.split("/")
        # Locate the exact path segment pair ("layers", "0"). Matching whole segments (instead of
        # substring-replacing every "0" in the key) leaves any other "0" characters in the path untouched.
        layer_pos = next(
            (j + 1 for j in range(len(parts) - 1) if parts[j] == "layers" and parts[j + 1] == "0"),
            None,
        )
        if layer_pos is None:
            # Not the layer-0 leaf of an unrolled stack (either a non-layer param, or layers 1..N-1,
            # which are consumed below when their layer-0 sibling is visited).
            continue

        head, tail = parts[:layer_pos], parts[layer_pos + 1 :]
        # Only the path *before* the layer index decides encoder vs. decoder, so sub-module names inside
        # a layer cannot influence the choice.
        if any("decoder" in segment for segment in head):
            block_prefix = "Decoder"
            num_hidden_layers = self.config.decoder_layers
        else:
            block_prefix = "Encoder"
            num_hidden_layers = self.config.encoder_layers

        # Squash the keys for the N unrolled layers into one single key:
        # (layers/0, ..., layers/N-1) -> layers/Flax{Encoder,Decoder}ScanLayers
        scan_key = "/".join([*head, f"Flax{block_prefix}ScanLayers", *tail])
        # Pop each unrolled layer (0, ..., N-1) so the returned tree does not also carry the unrolled copies.
        # `pop` raises KeyError if a layer index is missing.
        stacked_params = [params.pop("/".join([*head, str(i), *tail])) for i in range(num_hidden_layers)]
        # Stack along a new leading axis: index i of the result is layer i.
        params[scan_key] = jnp.stack(stacked_params)

    # Finally, unflatten the dict to restore the nested pytree structure
    params = unflatten_dict(params, sep="/")
    return params