def setup()

in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def setup(self) -> None:
        """Create the submodules and bind them as attributes.

        In order, this assigns ``conv1``, ``conv2``, ``dropout_layer``,
        ``layers``, ``embed_positions`` and ``layer_norm``. Sizes and the
        dropout rate are read from ``self.config``; every submodule that
        accepts them receives ``self.dtype`` and ``self.params_dtype``.
        """
        cfg = self.config

        # dtype settings forwarded unchanged to every submodule that takes them.
        dtype_kwargs = {"dtype": self.dtype, "params_dtype": self.params_dtype}
        # Arguments common to both convolutions: width-3 kernel, padding of 1.
        conv_kwargs = {"kernel_size": (3,), "padding": 1, **dtype_kwargs}

        self.conv1 = Conv(
            cfg.d_model,
            kernel_axes=("channels", "num_mel", "embed"),
            **conv_kwargs,
        )
        # Only the second convolution is strided; its kernel axis names also
        # differ from conv1's and are kept exactly as configured.
        self.conv2 = Conv(
            cfg.d_model,
            strides=2,
            kernel_axes=("channels", "embed", "num_mel"),
            **conv_kwargs,
        )

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        # The layer collection additionally receives the scan and
        # gradient-checkpointing switches held on this module.
        self.layers = FlaxWhisperEncoderLayerCollection(
            cfg,
            use_scan=self.use_scan,
            gradient_checkpointing=self.gradient_checkpointing,
            **dtype_kwargs,
        )

        # Embedding built from (max_source_positions, d_model).
        self.embed_positions = Embed(
            cfg.max_source_positions,
            cfg.d_model,
            **dtype_kwargs,
        )

        self.layer_norm = LayerNorm(epsilon=1e-05, **dtype_kwargs)