Excerpt: the `setup()` method from training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def setup(self) -> None:
        """Build this encoder layer's submodules.

        Creates the self-attention block with its LayerNorm, the two
        position-wise feed-forward projections (embed_dim -> encoder_ffn_dim
        -> embed_dim) with a final LayerNorm, and the dropout layers, all
        parameterized from ``self.config``.
        """
        cfg = self.config
        # Compute/param dtypes are passed identically to every submodule.
        dtypes = {"dtype": self.dtype, "params_dtype": self.params_dtype}

        self.embed_dim = cfg.d_model

        # Multi-head self-attention and its pre/post normalization.
        self.self_attn = FlaxWhisperAttention(
            config=cfg,
            embed_dim=self.embed_dim,
            num_heads=cfg.encoder_attention_heads,
            dropout=cfg.attention_dropout,
            **dtypes,
        )
        self.self_attn_layer_norm = LayerNorm(epsilon=1e-05, **dtypes)

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)
        self.activation_fn = ACT2FN[cfg.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=cfg.activation_dropout)

        # Feed-forward projections; kernel_axes carry the logical axis
        # labels used elsewhere in this codebase for the FFN kernels.
        self.fc1 = DenseGeneral(
            cfg.encoder_ffn_dim,
            kernel_axes=("embed", "mlp"),
            **dtypes,
        )
        self.fc2 = DenseGeneral(
            self.embed_dim,
            kernel_axes=("mlp", "embed"),
            **dtypes,
        )
        self.final_layer_norm = LayerNorm(epsilon=1e-05, **dtypes)