in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def setup(self) -> None:
    """Create the encoder and decoder sub-modules.

    Both sub-modules receive ``self.config`` positionally, plus the same
    ``dtype``, ``params_dtype``, ``use_scan`` and ``gradient_checkpointing``
    values read from this module.
    """
    # Identical settings go to both halves of the model, so collect them once
    # and unpack them into each constructor. Each ** unpacking builds a fresh
    # kwargs mapping, so the two calls do not share mutable state.
    shared_kwargs = dict(
        dtype=self.dtype,
        params_dtype=self.params_dtype,
        use_scan=self.use_scan,
        gradient_checkpointing=self.gradient_checkpointing,
    )
    self.encoder = FlaxWhisperEncoder(self.config, **shared_kwargs)
    self.decoder = FlaxWhisperDecoder(self.config, **shared_kwargs)