def setup()

in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def setup(self) -> None:
        """Construct the decoder's submodules.

        Builds, in order: a token embedding table (``config.vocab_size`` rows),
        a position embedding table (``config.max_target_positions`` rows), the
        decoder layer stack, a dropout layer with rate ``config.dropout``, and a
        final layer norm with ``epsilon=1e-5``. The tables are ``config.d_model``
        wide. Every submodule except the dropout receives ``self.dtype`` as
        ``dtype`` and ``self.params_dtype`` as ``params_dtype``.
        """
        cfg = self.config
        # Precision settings shared by every parameterised submodule below.
        precision_kwargs = {"dtype": self.dtype, "params_dtype": self.params_dtype}

        # Lookup tables for tokens and for target positions, both d_model wide.
        self.embed_tokens = Embed(cfg.vocab_size, cfg.d_model, **precision_kwargs)
        self.embed_positions = Embed(cfg.max_target_positions, cfg.d_model, **precision_kwargs)

        # Decoder layers; use_scan / gradient_checkpointing are forwarded unchanged.
        self.layers = FlaxWhisperDecoderLayerCollection(
            cfg,
            use_scan=self.use_scan,
            gradient_checkpointing=self.gradient_checkpointing,
            **precision_kwargs,
        )

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        # Final normalisation with a fixed epsilon of 1e-5.
        self.layer_norm = LayerNorm(epsilon=1e-5, **precision_kwargs)