in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def setup(self) -> None:
    """Create the submodules owned by this module.

    Builds, in order: two convolutions with a length-3 kernel (the second
    one strided by 2), a dropout layer, the encoder layer collection, a
    position embedding and a layer norm. Every parameterised submodule is
    given this module's ``dtype`` and ``params_dtype``.
    """
    cfg = self.config
    d_model = cfg.d_model
    # The same dtype pair is forwarded to every parameterised submodule.
    dtype_kwargs = {"dtype": self.dtype, "params_dtype": self.params_dtype}

    # Both convolutions produce ``d_model`` features with kernel size 3 and
    # padding 1; only the second one uses a stride.
    # NOTE(review): the two ``kernel_axes`` tuples list "num_mel" and "embed"
    # in opposite order -- presumably logical axis names used for
    # partitioning; confirm the ordering on conv2 is intentional.
    self.conv1 = Conv(
        d_model,
        kernel_size=(3,),
        padding=1,
        kernel_axes=("channels", "num_mel", "embed"),
        **dtype_kwargs,
    )
    self.conv2 = Conv(
        d_model,
        kernel_size=(3,),
        strides=2,
        padding=1,
        kernel_axes=("channels", "embed", "num_mel"),
        **dtype_kwargs,
    )

    # Dropout rate is taken from the model config.
    self.dropout_layer = nn.Dropout(rate=cfg.dropout)

    # Scan and gradient-checkpointing switches are passed straight through
    # from this module's own attributes.
    self.layers = FlaxWhisperEncoderLayerCollection(
        cfg,
        use_scan=self.use_scan,
        gradient_checkpointing=self.gradient_checkpointing,
        **dtype_kwargs,
    )

    # Sized from ``max_source_positions`` and ``d_model``.
    self.embed_positions = Embed(cfg.max_source_positions, d_model, **dtype_kwargs)

    self.layer_norm = LayerNorm(epsilon=1e-05, **dtype_kwargs)